diff --git a/.flake8 b/.flake8 index 609fa2c03..41d8799c8 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ [flake8] show-source=true statistics=true -max-line-length = 80 +max-line-length = 88 per-file-ignores = # line too long icefall/diagnostics.py: E501, @@ -11,7 +11,8 @@ per-file-ignores = egs/*/ASR/*/scaling.py: E501, egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203 egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203 - egs/librispeech/ASR/conformer_ctc2/*py: E501, + egs/librispeech/ASR/conformer_ctc*/*py: E501, + egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203 egs/librispeech/ASR/RESULTS.md: E999, # invalid escape sequence (cause by tex formular), W605 diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..5d65b98e9 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,3 @@ +# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890) +107df3b115a58f1b68a6458c3f94a130004be34c +d31db010371a4128856480382876acdc0d1739ed diff --git a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh index bb7c7dfdc..0bec8c0c4 100755 --- a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh +++ b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh @@ -15,5 +15,5 @@ mkdir -p data cd data [ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank cd .. -./local/compute_fbank_librispeech.py +./local/compute_fbank_librispeech.py --dataset 'test-clean test-other' ls -lh data/fbank/ diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh index e70a1848d..4c393f6be 100755 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh @@ -25,7 +25,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh new file mode 100755 index 000000000..c68ccc954 --- /dev/null +++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27 + +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull --include "data/lang_bpe_500/LG.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "data/lm/G_4_gram.pt" +git lfs pull --include "exp/jit_trace.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Decode with models exported by torch.jit.trace()" + +for m in ctc-decoding 1best; do + ./conformer_ctc3/jit_pretrained.py \ + --model-filename 
$repo/exp/jit_trace.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +log "Export to torchscript model" + +./conformer_ctc3/export.py \ + --exp-dir $repo/exp \ + --lang-dir $repo/data/lang_bpe_500 \ + --jit-trace 1 \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.trace()" + +for m in ctc-decoding 1best; do + ./conformer_ctc3/jit_pretrained.py \ + --model-filename $repo/exp/jit_trace.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for m in ctc-decoding 1best; do + ./conformer_ctc3/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p conformer_ctc3/exp + ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh conformer_ctc3/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in ctc-decoding 1best; do + log "Decoding with $method" + ./conformer_ctc3/decode.py \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir conformer_ctc3/exp/ \ + --max-duration $max_duration \ + --decoding-method $method \ + --lm-dir data/lm + done + + rm conformer_ctc3/exp/*.pt +fi diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh new file mode 100755 index 000000000..4cd2c4bec --- /dev/null +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -0,0 +1,191 @@ +#!/usr/bin/env bash +# +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) +abs_repo=$(realpath $repo) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s pretrained-iter-468000-avg-16.pt pretrained.pt +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +log "Test exporting with torch.jit.trace()" + +./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --bpe-model 
$repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --jit-trace 1 + +log "Decode with models exported by torch.jit.trace()" + +./lstm_transducer_stateless2/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./lstm_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./lstm_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" + +if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then + lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + log "Download pre-trained RNN-LM model from ${lm_repo_url}" + GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url + lm_repo=$(basename $lm_repo_url) + pushd $lm_repo + git lfs pull --include "exp/pretrained.pt" + mv exp/pretrained.pt exp/epoch-88.pt + popd + + mkdir -p lstm_transducer_stateless2/exp + ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other with RNN LM" + + ./lstm_transducer_stateless2/decode.py \ + --use-averaged-model 0 \ + --epoch 999 \ + --avg 1 \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam 4 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir $lm_repo/exp \ + --lm-epoch 88 \ + --lm-avg 1 \ + --lm-scale 0.3 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 +fi + +if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then + bigram_repo_url=https://huggingface.co/marcoyang/librispeech_bigram + log "Download bi-gram LM from ${bigram_repo_url}" + GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url + bigramlm_repo=$(basename $bigram_repo_url) + pushd $bigramlm_repo + git lfs pull --include "2gram.fst.txt" + cp 2gram.fst.txt $abs_repo/data/lang_bpe_500/. 
+ popd + + lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + log "Download pre-trained RNN-LM model from ${lm_repo_url}" + GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url + lm_repo=$(basename $lm_repo_url) + pushd $lm_repo + git lfs pull --include "exp/pretrained.pt" + mv exp/pretrained.pt exp/epoch-88.pt + popd + + mkdir -p lstm_transducer_stateless2/exp + ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other" + + ./lstm_transducer_stateless2/decode.py \ + --use-averaged-model 0 \ + --epoch 999 \ + --avg 1 \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_LODR \ + --beam 4 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir $lm_repo/exp \ + --lm-scale 0.4 \ + --lm-epoch 88 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 \ + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 +fi + +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then + mkdir -p lstm_transducer_stateless2/exp + ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./lstm_transducer_stateless2/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir lstm_transducer_stateless2/exp + done + + rm lstm_transducer_stateless2/exp/*.pt +fi diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml deleted file mode 100755 index 6ce92d022..000000000 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env bash -# -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -soxi $repo/test_wavs/*.wav -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-iter-468000-avg-16.pt pretrained.pt -ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt -popd - -log "Install ncnn and pnnx" - -# We are using a modified ncnn here. Will try to merge it to the official repo -# of ncnn -git clone https://github.com/csukuangfj/ncnn -pushd ncnn -git submodule init -git submodule update python/pybind11 -python3 setup.py bdist_wheel -ls -lh dist/ -pip install dist/*.whl -cd tools/pnnx -mkdir build -cd build -cmake .. 
-make -j4 pnnx - -./src/pnnx || echo "pass" - -popd - -log "Test exporting to pnnx format" - -./lstm_transducer_stateless2/export.py \ - --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --pnnx 1 - -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt -./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt - -./lstm_transducer_stateless2/ncnn-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ - $repo/test_wavs/1089-134686-0001.wav - -./lstm_transducer_stateless2/streaming-ncnn-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ - $repo/test_wavs/1089-134686-0001.wav - - - -log "Test exporting with torch.jit.trace()" - -./lstm_transducer_stateless2/export.py \ - --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --jit-trace 1 - -log "Decode with models exported by torch.jit.trace()" - -./lstm_transducer_stateless2/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ - --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ - --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Test exporting to ONNX" - -./lstm_transducer_stateless2/export.py \ - --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --onnx 1 - -log "Decode with ONNX models " - -./lstm_transducer_stateless2/streaming-onnx-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo//exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1089-134686-0001.wav - -./lstm_transducer_stateless2/streaming-onnx-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo//exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1221-135766-0001.wav - 
-./lstm_transducer_stateless2/streaming-onnx-decode.py \ - --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo//exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1221-135766-0002.wav - - - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./lstm_transducer_stateless2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./lstm_transducer_stateless2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" - -if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then - lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm - log "Download pre-trained RNN-LM model from ${lm_repo_url}" - git clone $lm_repo_url - lm_repo=$(basename $lm_repo_url) - pushd $lm_repo - git lfs pull --include "exp/pretrained.pt" - cd exp - ln -s pretrained.pt epoch-88.pt - popd - - ./lstm_transducer_stateless2/decode.py \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_bpe_500 \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ - --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir $lm_repo/exp \ - --rnn-lm-epoch 88 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 -fi - -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then - mkdir -p lstm_transducer_stateless2/exp - ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh lstm_transducer_stateless2/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./lstm_transducer_stateless2/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir lstm_transducer_stateless2/exp - done - - rm lstm_transducer_stateless2/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index dafea56db..6792c7088 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git 
a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index c3d07dc0e..dbf678d72 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -23,7 +23,6 @@ popd log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index 22de3b45d..b6d477afe 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -22,7 +22,6 @@ popd log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index 880767443..efa4b53f0 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp @@ -27,14 +26,6 @@ ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt popd -log "Test exporting to ONNX format" - -./pruned_transducer_stateless3/export.py \ - --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --epoch 99 \ - --avg 1 \ - --onnx 1 log "Export to torchscript model" ./pruned_transducer_stateless3/export.py \ @@ -51,30 +42,8 @@ log "Export to torchscript model" --avg 1 \ --jit-trace 1 -ls -lh $repo/exp/*.onnx ls -lh $repo/exp/*.pt -log "Decode with ONNX models" - -./pruned_transducer_stateless3/onnx_check.py \ - --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-encoder-filename $repo/exp/encoder.onnx \ - --onnx-decoder-filename $repo/exp/decoder.onnx \ - --onnx-joiner-filename $repo/exp/joiner.onnx \ - --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx - -./pruned_transducer_stateless3/onnx_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - log "Decode with models exported by torch.jit.trace()" ./pruned_transducer_stateless3/jit_pretrained.py \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index c6a781318..511fe0c9e 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls 
-lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 75861bbc7..2bc179c86 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp @@ -33,6 +32,7 @@ popd log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ + --use-averaged-model false \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --epoch 99 \ --avg 1 \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh new file mode 100755 index 000000000..192438353 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -0,0 +1,150 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01 + +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull --include "data/lang_bpe_500/LG.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "data/lm/G_4_gram.pt" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7_ctc/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + 
$repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7_ctc/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7_ctc/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7_ctc/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless7_ctc/exp + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 999 \ + --avg 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration $max_duration \ + --use-averaged-model 0 \ + --decoding-method $m \ + --hlg-scale 0.6 \ + --lm-dir data/lm + done + + rm pruned_transducer_stateless7_ctc/exp/*.pt +fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh new file mode 100755 index 000000000..761eb72e2 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh @@ -0,0 +1,147 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull --include "data/lang_bpe_500/LG.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7_ctc_bs/export.py \ + 
--exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" + +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7_ctc_bs/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc_bs/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7_ctc_bs/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7_ctc_bs/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 999 \ + --avg 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration $max_duration \ + --use-averaged-model 0 \ + --decoding-method $m \ + --hlg-scale 0.6 + done + + rm pruned_transducer_stateless7_ctc_bs/exp/*.pt +fi diff --git 
a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh new file mode 100755 index 000000000..e1e4e1f10 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "exp/encoder_jit_trace.pt" +git lfs pull --include "exp/decoder_jit_trace.pt" +git lfs pull --include "exp/joiner_jit_trace.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7_streaming/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Export to torchscript model by torch.jit.trace()" +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 + +log "Decode with models exported by torch.jit.trace()" + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo 
"GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7_streaming/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_streaming/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7_streaming/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + num_decode_stream=200 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "decoding with $method" + + ./pruned_transducer_stateless7_streaming/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp + done + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --decode-chunk-len 32 \ + --num-decode-streams $num_decode_stream + --exp-dir pruned_transducer_stateless7_streaming/exp + done + + rm pruned_transducer_stateless7_streaming/exp/*.pt +fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh new file mode 100755 index 000000000..5d9485692 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Export to torchscript model" +./pruned_transducer_stateless8/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model false \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless8/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint 
$repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless8/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless8/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless8/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless8/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless8/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless8/exp + done + + rm pruned_transducer_stateless8/exp/*.pt +fi diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index af37102d5..77cd59506 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index 5b8ed396b..b4aca1b6b 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh new file mode 100755 index 000000000..a58b8ec56 --- /dev/null +++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08 + +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/lang_bpe_500/3gram.pt" +git lfs pull --include "data/lang_bpe_500/4gram.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull 
--include "data/lang_bpe_500/LG.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./zipformer_mmi/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./zipformer_mmi/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + --lang-dir $repo/data/lang_bpe_500 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do + log "$method" + + ./zipformer_mmi/pretrained.py \ + --method $method \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_bpe_500 \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p zipformer_mmi/exp + ln -s $PWD/$repo/exp/pretrained.pt zipformer_mmi/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh zipformer_mmi/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do + log "Decoding with $method" + + ./zipformer_mmi/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --nbest-scale 1.2 \ + --hp-scale 1.0 \ + --max-duration $max_duration \ + --lang-dir $repo/data/lang_bpe_500 \ + --exp-dir zipformer_mmi/exp + done + + rm zipformer_mmi/exp/*.pt +fi diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 96c320616..125d1f3b1 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.flac ls -lh $repo/test_wavs/*.flac log "CTC decoding" diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index 209d4814f..89115e88d 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 34ff76fe4..85e2c89e6 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -19,7 +19,6 @@ repo=$(basename 
$repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh index 75650c2d3..0644d9be0 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh index bcc2d74cb..79fb64311 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index d3e40315a..41456f11b 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav for sym in 1 2 3; do diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index cfa006776..1331c966c 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -19,7 +19,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav log "Beam search decoding" diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh index 2d237dcf2..90097c752 100755 --- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -20,7 +20,6 @@ repo=$(basename $repo_url) log "Display test files" tree $repo/ -soxi $repo/test_wavs/*.wav ls -lh $repo/test_wavs/*.wav pushd $repo/exp diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh new file mode 100755 index 000000000..52491d2ea --- /dev/null +++ b/.github/scripts/test-ncnn-export.sh @@ -0,0 +1,234 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +pushd egs/librispeech/ASR + +log "Install ncnn and pnnx" + +# We are using a modified ncnn here. Will try to merge it to the official repo +# of ncnn +git clone https://github.com/csukuangfj/ncnn +pushd ncnn +git submodule init +git submodule update python/pybind11 +python3 setup.py bdist_wheel +ls -lh dist/ +pip install dist/*.whl +cd tools/pnnx +mkdir build +cd build + +echo "which python3" + +which python3 +#/opt/hostedtoolcache/Python/3.8.16/x64/bin/python3 + +cmake -D Python3_EXECUTABLE=$(which python3) .. 
+make -j4 pnnx + +./src/pnnx || echo "pass" + +popd + +export PATH=$PWD/ncnn/tools/pnnx/build/src:$PATH + +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +pnnx $repo/exp/encoder_jit_trace-pnnx.pt +pnnx $repo/exp/decoder_jit_trace-pnnx.pt +pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +python3 ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./lstm_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 + +pnnx $repo/exp/encoder_jit_trace-pnnx.pt +pnnx $repo/exp/decoder_jit_trace-pnnx.pt +pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +python3 ./lstm_transducer_stateless2/ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + 
--joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +pnnx $repo/exp/encoder_jit_trace-pnnx.pt +pnnx $repo/exp/decoder_jit_trace-pnnx.pt +pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_char_bpe/L.pt" +git lfs pull --include "data/lang_char_bpe/L_disambig.pt" +git lfs pull --include "data/lang_char_bpe/Linv.pt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ + --lang-dir $repo/data/lang_char_bpe \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +pnnx $repo/exp/encoder_jit_trace-pnnx.pt +pnnx $repo/exp/decoder_jit_trace-pnnx.pt +pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + 
--encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/0.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh new file mode 100755 index 000000000..39467c44a --- /dev/null +++ b/.github/scripts/test-onnx-export.sh @@ -0,0 +1,351 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + + + +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless7_streaming/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +log "Export via torch.jit.script()" + +./pruned_transducer_stateless3/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 
9999 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless3/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless3/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + + +log "==========================================================================" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-39-avg-7.pt" + +cd exp +ln -s pretrained-epoch-39-avg-7.pt epoch-99.pt +popd + +log "Export via torch.jit.script()" + +./pruned_transducer_stateless5/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --jit 1 + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless5/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless5/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless5/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url= + +rm -rf $repo +log "--------------------------------------------------------------------------" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename 
$repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.script()" + +./pruned_transducer_stateless7/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --jit 1 + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./pruned_transducer_stateless7/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +log "Test exporting to ONNX format" + +./conv_emformer_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +log "Run onnx_pretrained.py" + +./conv_emformer_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1221-135766-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" + +log "==========================================================================" +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +log "Export via torch.jit.trace()" + +./lstm_transducer_stateless2/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit-trace 1 + +log "Test exporting to ONNX 
format" + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./lstm_transducer_stateless2/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./lstm_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1221-135766-0001.wav + +rm -rf $repo +log "--------------------------------------------------------------------------" diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index dd0969f51..d7fe2c964 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -26,6 +26,10 @@ on: pull_request: types: [labeled] +concurrency: + group: build_doc-${{ github.ref }} + cancel-in-progress: true + jobs: build-doc: if: github.event.label.name == 'doc' || github.event_name == 'push' diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml index e46b01a08..f5ba73195 100644 --- a/.github/workflows/run-aishell-2022-06-20.yml +++ b/.github/workflows/run-aishell-2022-06-20.yml @@ -34,6 +34,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_aishell_2022_06_20-${{ github.ref }} + cancel-in-progress: true + jobs: run_aishell_2022_06_20: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -61,7 +65,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -83,7 +87,7 @@ jobs: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index c631927fa..c7b9cc79d 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_gigaspeech_2022_05_13-${{ github.ref }} + cancel-in-progress: true + jobs: run_gigaspeech_2022_05_13: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -60,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install 
--no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml index 5df710006..9c7cd1228 100644 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ b/.github/workflows/run-librispeech-2022-03-12.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_03_12-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_03_12: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -60,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml index 24c062442..78c9e759f 100644 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ b/.github/workflows/run-librispeech-2022-04-29.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_04_29-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_04_29: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -60,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml index 29215ec25..04799bf52 100644 --- a/.github/workflows/run-librispeech-2022-05-13.yml +++ b/.github/workflows/run-librispeech-2022-05-13.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_05_13-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_05_13: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -60,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo 
apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml index 3b98b500e..6dfc23920 100644 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_11_11_zipformer-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_11_11_zipformer: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -60,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml new file mode 100644 index 000000000..0544e68b3 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml @@ -0,0 +1,159 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
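+
+# Summary (added comment): this workflow runs
+# .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
+# on pushes to master, on the `ready`/`run-decode` labels, and on the nightly
+# schedule. It installs the CI dependencies, caches kaldifeat and the
+# LibriSpeech test-clean/test-other data, and on scheduled or `run-decode`
+# runs prints and uploads decoding results for greedy_search,
+# fast_beam_search, and modified_beam_search.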
+ +name: run-librispeech-2022-11-14-stateless8 +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_librispeech_2022_11_14_zipformer_stateless8-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_librispeech_2022_11_14_zipformer_stateless8: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless8 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless8/exp + + cd pruned_transducer_stateless8 + echo "results for 
pruned_transducer_stateless8" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless8 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless8-2022-11-14 + path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml new file mode 100644 index 000000000..62e1f2a01 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml @@ -0,0 +1,163 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
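+
+# Summary (added comment): nightly/labeled CI for pruned_transducer_stateless7_ctc.
+# It runs .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+# and, on scheduled or `run-decode` runs, prints and uploads results for
+# greedy_search, fast_beam_search, modified_beam_search, ctc-decoding, and 1best.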
+ +name: run-librispeech-2022-12-01-stateless7-ctc +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_2022_11_11_zipformer: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless7_ctc/exp + + cd pruned_transducer_stateless7_ctc + echo "results for pruned_transducer_stateless7_ctc" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best 
for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===ctc decoding===" + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===1best===" + find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01 + path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml new file mode 100644 index 000000000..7dc33aaa9 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml @@ -0,0 +1,167 @@ +# Copyright 2022 Zengwei Yao + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-librispeech-2022-12-08-zipformer-mmi +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_librispeech_2022_12_08_zipformer-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_librispeech_2022_12_08_zipformer: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh + + - name: Display decoding results for librispeech zipformer-mmi + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./zipformer-mmi/exp + + cd zipformer-mmi + echo "results for zipformer-mmi" + echo "===1best===" + find exp/1best -name "log-*" -exec grep -n --color "best for 
test-clean" {} + | sort -n -k2 + find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===nbest===" + find exp/nbest -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/nbest -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===nbest-rescoring-LG===" + find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===nbest-rescoring-3-gram===" + find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===nbest-rescoring-4-gram===" + find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech zipformer-mmi + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08 + path: egs/librispeech/ASR/zipformer_mmi/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml new file mode 100644 index 000000000..de55847ad --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml @@ -0,0 +1,163 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-librispeech-2022-12-15-stateless7-ctc-bs +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_2022_12_15_zipformer_ctc_bs: + if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./pruned_transducer_stateless7_ctc_bs/exp + + cd pruned_transducer_stateless7_ctc_bs + echo "results for pruned_transducer_stateless7_ctc_bs" + echo "===greedy search===" + find exp/greedy_search -name "log-*" 
-exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===ctc decoding===" + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===1best===" + find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc_bs + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15 + path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml new file mode 100644 index 000000000..feb5c6fd0 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml @@ -0,0 +1,172 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-librispeech-2022-12-29-stateless7-streaming +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_librispeech_2022_12_29_zipformer_streaming-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_librispeech_2022_12_29_zipformer_streaming: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh + + - name: Display decoding results for librispeech pruned_transducer_stateless7_streaming + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree 
./pruned_transducer_stateless7_streaming/exp + + cd pruned_transducer_stateless7_streaming + echo "results for pruned_transducer_stateless7_streaming" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming greedy search===" + find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming fast_beam_search===" + find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===streaming modified beam search===" + find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + + - name: Upload decoding results for librispeech pruned_transducer_stateless7_streaming + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-streaming-2022-12-29 + path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/ diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml new file mode 100644 index 000000000..c95ed8b9a --- /dev/null +++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml @@ -0,0 +1,155 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
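+
+# Summary (added comment): CI for the conformer_ctc3 recipe. It runs
+# .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh and, on scheduled
+# or `run-decode` runs, prints and uploads decoding results for ctc-decoding
+# and 1best on test-clean and test-other.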
+ +name: run-librispeech-conformer-ctc3-2022-11-28 +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_librispeech_2022_11_28_conformer_ctc3-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_librispeech_2022_11_28_conformer_ctc3: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh + + - name: Display decoding results for librispeech conformer_ctc3 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./conformer_ctc3/exp + + cd conformer_ctc3 + echo "results for conformer_ctc3" + echo "===ctc-decoding===" + find exp/ctc-decoding -name "log-*" 
-exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===1best===" + find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech conformer_ctc3 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28 + path: egs/librispeech/ASR/conformer_ctc3/exp/ diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index a90841fb6..e14d4e92f 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -16,9 +16,13 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_lstm_transducer_stateless2_2022_09_03-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_lstm_transducer_stateless2_2022_09_03: - if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: @@ -43,7 +47,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -102,12 +106,12 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml + .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh - name: Display decoding results for lstm_transducer_stateless2 if: github.event_name == 'schedule' @@ -135,13 +139,25 @@ jobs: cd egs/librispeech/ASR tree lstm_transducer_stateless2/exp cd lstm_transducer_stateless2/exp - echo "===modified_beam_search_rnnlm_shallow_fusion===" - find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + echo "===modified_beam_search_lm_shallow_fusion===" + echo "===Using RNNLM===" + find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Display decoding results for lstm_transducer_stateless2 + if: github.event.label.name 
== 'LODR' + shell: bash + run: | + cd egs/librispeech/ASR + tree lstm_transducer_stateless2/exp + cd lstm_transducer_stateless2/exp + echo "===modified_beam_search_rnnlm_LODR===" + find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for lstm_transducer_stateless2 uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' + if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03 path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/ diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml index 66a2c240b..73d91fcd4 100644 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ -33,9 +33,13 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_pruned_transducer_stateless3_2022_05_13-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_pruned_transducer_stateless3_2022_05_13: - if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: @@ -60,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml index 55428861c..8a690393e 100644 --- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml +++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_streaming_2022_06_26-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_streaming_2022_06_26: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -60,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: 
my-cache @@ -119,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml index f520405e1..217dbdfa1 100644 --- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_04_19-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_04_19: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -60,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index 9bc6a481f..4e8e7b8db 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -23,6 +23,10 @@ on: pull_request: types: [labeled] +concurrency: + group: run_pre_trained_conformer_ctc-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_conformer_ctc: if: github.event.label.name == 'ready' || github.event_name == 'push' @@ -50,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -69,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml index 7a0f30b0f..ddde4f1d6 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml @@ -32,6 +32,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || 
github.event_name == 'push' || github.event_name == 'schedule' @@ -59,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -118,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml index 797f3fe50..00ea97b2a 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml @@ -32,6 +32,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -59,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -118,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 29e665881..b3cfc9efd 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -23,6 +23,10 @@ on: pull_request: types: [labeled] +concurrency: + group: run_pre_trained_transducer_stateless_modified_2_aishell-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless_modified_2_aishell: if: github.event.label.name == 'ready' || github.event_name == 'push' @@ -50,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -69,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml 
b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml index 6193f28e7..ab598541d 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -23,6 +23,10 @@ on: pull_request: types: [labeled] +concurrency: + group: run_pre_trained_transducer_stateless_modified_aishell-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless_modified_aishell: if: github.event.label.name == 'ready' || github.event_name == 'push' @@ -50,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -69,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 32208076c..d663d49dd 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -32,6 +32,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_pre_trained_transducer_stateless-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' @@ -59,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -118,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml index 965d0f655..9cb9d3b59 100644 --- a/.github/workflows/run-pretrained-transducer.yml +++ b/.github/workflows/run-pretrained-transducer.yml @@ -23,6 +23,10 @@ on: pull_request: types: [labeled] +concurrency: + group: run_pre_trained_transducer-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer: if: github.event.label.name == 'ready' || github.event_name == 'push' @@ -50,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -69,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git 
a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml new file mode 100644 index 000000000..f8d9c02c5 --- /dev/null +++ b/.github/workflows/run-ptb-rnn-lm.yml @@ -0,0 +1,71 @@ +name: run-ptb-rnn-lm-training + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: run_ptb_rnn_lm_training-${{ github.ref }} + cancel-in-progress: true + +jobs: + run_ptb_rnn_lm_training: + if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.8"] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Prepare data + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/ptb/LM + ./prepare.sh + + - name: Run training + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/ptb/LM + ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2 + + - name: Upload pretrained models + uses: actions/upload-artifact@v2 + if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule' + with: + name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb + path: egs/ptb/LM/my-rnnlm-exp/ diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml index d96a3bfe6..14fb96ec8 100644 --- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml +++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml @@ -23,8 +23,12 @@ on: pull_request: types: [labeled] +concurrency: + group: run_wenetspeech_pruned_transducer_stateless2-${{ github.ref }} + cancel-in-progress: true + jobs: - run_librispeech_pruned_transducer_stateless3_2022_05_13: + run_wenetspeech_pruned_transducer_stateless2: if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech' runs-on: ${{ matrix.os }} strategy: @@ -50,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -72,7 +76,7 @@ jobs: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index ce77c47df..83a1d5462 100644 --- 
a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -21,17 +21,21 @@ on: branches: - master pull_request: - types: [labeled] + branches: + - master + +concurrency: + group: run-yesno-recipe-${{ github.ref }} + cancel-in-progress: true jobs: run-yesno-recipe: - if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: matrix: # os: [ubuntu-18.04, macos-10.15] # TODO: enable macOS for CPU testing - os: [ubuntu-18.04] + os: [ubuntu-latest] python-version: [3.8] fail-fast: false @@ -61,9 +65,9 @@ jobs: - name: Install Python dependencies run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Run yesno recipe shell: bash diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 90459bc1c..fc1dcbfd4 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -24,6 +24,10 @@ on: branches: - master +concurrency: + group: style_check-${{ github.ref }} + cancel-in-progress: true + jobs: style_check: runs-on: ${{ matrix.os }} @@ -45,17 +49,18 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 - # See https://github.com/psf/black/issues/2964 - # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 + python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 + # Click issue fixed in https://github.com/psf/black/pull/2966 - name: Run flake8 shell: bash working-directory: ${{github.workspace}} run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --show-source --statistics - flake8 . + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 - name: Run black shell: bash diff --git a/.github/workflows/test-ncnn-export.yml b/.github/workflows/test-ncnn-export.yml new file mode 100644 index 000000000..cdea54854 --- /dev/null +++ b/.github/workflows/test-ncnn-export.yml @@ -0,0 +1,75 @@ +name: test-ncnn-export + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: test_ncnn_export-${{ github.ref }} + cancel-in-progress: true + +jobs: + test_ncnn_export: + if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Test ncnn export + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/test-ncnn-export.sh diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml new file mode 100644 index 000000000..3dc4261ab --- /dev/null +++ b/.github/workflows/test-onnx-export.yml @@ -0,0 +1,75 @@ +name: test-onnx-export + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: test_onnx_export-${{ github.ref }} + cancel-in-progress: true + +jobs: + test_onnx_export: + if: github.event.label.name == 'ready' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + 
id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Test ONNX export + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/test-onnx-export.sh diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 04fc0265f..079772e97 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,26 +21,23 @@ on: branches: - master pull_request: - types: [labeled] + branches: + - master + +concurrency: + group: test-${{ github.ref }} + cancel-in-progress: true jobs: test: - if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: matrix: - # os: [ubuntu-18.04, macos-10.15] - # disable macOS test for now. - os: [ubuntu-18.04] - python-version: [3.7, 3.8] - torch: ["1.8.0", "1.11.0"] - torchaudio: ["0.8.0", "0.11.0"] - k2-version: ["1.15.1.dev20220427"] - exclude: - - torch: "1.8.0" - torchaudio: "0.11.0" - - torch: "1.11.0" - torchaudio: "0.8.0" + os: [ubuntu-latest] + python-version: ["3.8"] + torch: ["1.10.0"] + torchaudio: ["0.10.0"] + k2-version: ["1.23.2.dev20221201"] fail-fast: false @@ -59,7 +56,7 @@ jobs: run: | sudo apt update sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg - sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all + sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all - name: Install Python dependencies run: | @@ -67,21 +64,16 @@ jobs: # numpy 1.20.x does not support python 3.6 pip install numpy==1.19 pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then - pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - else - pip install torchaudio==${{ matrix.torchaudio }} - fi + pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ pip install git+https://github.com/lhotse-speech/lhotse # icefall requirements pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* pip install kaldifst pip install onnxruntime - pip install -r requirements.txt - name: Install graphviz @@ -121,19 +113,20 @@ jobs: cd ../pruned_transducer_stateless4 pytest -v -s + cd ../pruned_transducer_stateless7 + pytest -v -s + cd ../transducer_stateless pytest -v -s - if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then - cd ../transducer - pytest -v -s + # cd ../transducer + # pytest -v -s - cd ../transducer_stateless2 - pytest -v -s + cd ../transducer_stateless2 + pytest -v -s - cd ../transducer_lstm - pytest -v -s - fi + cd ../transducer_lstm + pytest -v -s - name: Run tests if: startsWith(matrix.os, 'macos') @@ -164,13 +157,11 @@ jobs: cd ../transducer_stateless pytest -v -s - if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then - cd ../transducer - pytest -v -s + # cd ../transducer + # pytest -v -s - cd ../transducer_stateless2 - pytest -v -s + 
cd ../transducer_stateless2 + pytest -v -s - cd ../transducer_lstm - pytest -v -s - fi + cd ../transducer_lstm + pytest -v -s diff --git a/.gitignore b/.gitignore index 406deff6a..8af05d884 100644 --- a/.gitignore +++ b/.gitignore @@ -11,5 +11,26 @@ log *.bak *-bak *bak.py + +# Ignore Mac system files +.DS_store + +# Ignore node_modules folder +node_modules + +# ignore .nfs + +.nfs* + +# Ignore all text files +*.txt + +# Ignore files related to API keys +.env + +# Ignore SASS config files +.sass-cache + *.param *.bin +.DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 446ba0fe7..5cb213327 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,38 @@ repos: - repo: https://github.com/psf/black - rev: 21.6b0 + rev: 22.3.0 hooks: - id: black - args: [--line-length=80] - additional_dependencies: ['click==8.0.1'] + args: ["--line-length=88"] + additional_dependencies: ['click==8.1.0'] exclude: icefall\/__init__\.py - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 5.0.4 hooks: - id: flake8 - args: [--max-line-length=80] + args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"] + + # What are we ignoring here? + # E203: whitespace before ':' + # E266: too many leading '#' for block comment + # E501: line too long + # F401: module imported but unused + # E402: module level import not at top of file + # F403: 'from module import *' used; unable to detect undefined names + # F841: local variable is assigned to but never used + # W503: line break before binary operator + # In addition, the default ignore list is: + # E121,E123,E126,E226,E24,E704,W503,W504 - repo: https://github.com/pycqa/isort - rev: 5.9.2 + rev: 5.10.1 hooks: - id: isort - args: [--profile=black, --line-length=80] + args: ["--profile=black"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.2.0 hooks: - id: check-executables-have-shebangs - id: end-of-file-fixer diff --git a/LICENSE b/LICENSE index ee06cfc77..d64569567 100644 --- a/LICENSE +++ b/LICENSE @@ -1,13 +1,4 @@ - Legal Notices - - NOTE (this is not from the Apache License): The copyright model is that - authors (or their employers, if noted in individual files) own their - individual contributions. The authors' contributions can be discerned - from the git history. - - ------------------------------------------------------------------------- - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ diff --git a/docker/README.md b/docker/README.md index 6f2314e96..c14b9bf75 100644 --- a/docker/README.md +++ b/docker/README.md @@ -2,7 +2,7 @@ 2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. -If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. +If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0. 
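For reference, the .pre-commit-config.yaml hunk earlier in this patch pins `black==22.3.0`, `flake8==5.0.4` (with the extended ignore list E203,E266,E501,F401,E402,F403,F841,W503) and `isort==5.10.1`, all at 88 columns. Below is a minimal sketch of how a contributor might run the same checks locally; the `pre-commit` commands are standard CLI usage, and the direct tool invocations only approximate what the hooks do (the file name is a placeholder):

```bash
# Run everything through pre-commit, exactly as the hooks are configured.
pip install pre-commit
pre-commit install            # install the git hook once
pre-commit run --all-files    # or let it run automatically on `git commit`

# Rough manual equivalent, using the versions and flags pinned above.
pip install black==22.3.0 flake8==5.0.4 isort==5.10.1
black --line-length=88 your_changed_file.py
flake8 --max-line-length=88 \
  --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 your_changed_file.py
isort --profile=black your_changed_file.py
```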
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with ```bash $ nvidia-smi -Tue Sep 20 00:26:13 2022 +Tue Sep 20 00:26:13 2022 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 | |-------------------------------+----------------------+----------------------+ @@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022 | 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ - + +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | @@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022 ``` ## Building images locally -If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. -For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. +If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. +For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. ```dockerfile ENV http_proxy=http://aaa.bb.cc.net:8080 \ https_proxy=http://aaa.bb.cc.net:8080 ``` -Then, proceed with these commands. +Then, proceed with these commands. ### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3: @@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall ``` ### Tips: -1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. +1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`. 2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`. -Overall, your docker run command should look like this. +Overall, your docker run command should look like this. ```bash docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 @@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re ### Linking to icefall in your host machine -If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. +If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. -Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. +Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below. Use these commands once you are inside the container. 
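Putting the surrounding README hunks together, here is a consolidated sketch of the host-volume workflow they describe; the curly-brace paths are placeholders, as in the README itself, and the optional proxy flags from the earlier hunk are omitted:

```bash
# Start the GPU container with the host machine mounted via -v (case (a) image).
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all \
  -v {/path/in/host/machine}:{/path/in/docker} icefall/pytorch1.12.1

# Inside the container: point /workspace/icefall at the mounted checkout
# (move any pre-existing /workspace/icefall out of the way first).
ln -s {/path/in/docker/to/icefall} /workspace/icefall

# Later, from the host: re-attach to the running container,
# or restart it after it has been stopped.
docker exec -it icefall /bin/bash
docker start -ai icefall
```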
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall docker exec -it icefall /bin/bash ``` -## Restarting a killed container that has been run before. +## Restarting a killed container that has been run before. ```bash docker start -ai icefall ``` @@ -111,4 +111,4 @@ docker start -ai icefall ## Sample usage of the CPU based images: ```bash docker run -it icefall /bin/bash -``` +``` diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile index 524303fb8..ff9e40604 100644 --- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -1,7 +1,7 @@ FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 # install normal source RUN apt-get update && \ @@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.12 && \ pip install graphviz - + #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ @@ -68,6 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ pip install -r requirements.txt +RUN pip install kaldifeat ENV PYTHONPATH /workspace/icefall:$PYTHONPATH -WORKDIR /workspace/icefall \ No newline at end of file +WORKDIR /workspace/icefall diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile index 17a8215f9..5c7423fa5 100644 --- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile @@ -1,12 +1,12 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel # ENV http_proxy=http://aaa.bbb.cc.net:8080 \ -# https_proxy=http://aaa.bbb.cc.net:8080 +# https_proxy=http://aaa.bbb.cc.net:8080 RUN rm /etc/apt/sources.list.d/cuda.list && \ rm /etc/apt/sources.list.d/nvidia-ml.list && \ apt-key del 7fa2af80 - + # install normal source RUN apt-get update && \ apt-get install -y --no-install-recommends \ @@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18 curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \ echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \ - rm -rf /var/lib/apt/lists/* && \ + rm -rf /var/lib/apt/lists/* && \ mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \ mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \ mv 
/opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \ @@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -# flac + +# flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ - cd /opt && \ + cd /opt && \ xz -d flac-1.3.2.tar.xz && \ tar -xvf flac-1.3.2.tar && \ cd flac-1.3.2 && \ @@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && make && make install && \ rm -rf flac-1.3.2.tar && \ find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ - cd - + cd - RUN conda install -y -c pytorch torchaudio=0.7.1 && \ pip install graphviz @@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ cd - # install lhotse -RUN pip install git+https://github.com/lhotse-speech/lhotse +RUN pip install git+https://github.com/lhotse-speech/lhotse RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ cd /workspace/icefall && \ @@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ ENV PYTHONPATH /workspace/icefall:$PYTHONPATH WORKDIR /workspace/icefall - diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..3abb38f8b --- /dev/null +++ b/docs/README.md @@ -0,0 +1,24 @@ + +## Usage + +```bash +cd /path/to/icefall/docs +pip install -r requirements.txt +make clean +make html +cd build/html +python3 -m http.server 8000 +``` + +It prints: + +``` +Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ... +``` + +Open your browser and go to to view the generated +documentation. + +Done! + +**Hint**: You can change the port number when starting the server. diff --git a/docs/source/conf.py b/docs/source/conf.py index 221d9d734..6901dec02 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -78,3 +78,15 @@ html_context = { } todo_include_todos = True + +rst_epilog = """ +.. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn +.. _sherpa-onnx: https://github.com/k2-fsa/sherpa-onnx +.. _icefall: https://github.com/k2-fsa/icefall +.. _git-lfs: https://git-lfs.com/ +.. _ncnn: https://github.com/tencent/ncnn +.. _LibriSpeech: https://www.openslr.org/12 +.. _musan: http://www.openslr.org/17/ +.. _ONNX: https://github.com/onnx/onnx +.. _onnxruntime: https://github.com/microsoft/onnxruntime +""" diff --git a/docs/source/contributing/code-style.rst b/docs/source/contributing/code-style.rst index 7d61a3ba1..3baaaeec2 100644 --- a/docs/source/contributing/code-style.rst +++ b/docs/source/contributing/code-style.rst @@ -11,9 +11,9 @@ We use the following tools to make the code style to be as consistent as possibl The following versions of the above tools are used: - - ``black == 12.6b0`` - - ``flake8 == 3.9.2`` - - ``isort == 5.9.2`` + - ``black == 22.3.0`` + - ``flake8 == 5.0.4`` + - ``isort == 5.10.1`` After running the following commands: @@ -54,10 +54,17 @@ it should succeed this time: If you want to check the style of your code before ``git commit``, you can do the following: + .. code-block:: bash + + $ pre-commit install + $ pre-commit run + +Or without installing the pre-commit hooks: + .. 
code-block:: bash $ cd icefall - $ pip install black==21.6b0 flake8==3.9.2 isort==5.9.2 + $ pip install black==22.3.0 flake8==5.0.4 isort==5.10.1 $ black --check your_changed_file.py $ black your_changed_file.py # modify it in-place $ diff --git a/docs/source/faqs.rst b/docs/source/faqs.rst new file mode 100644 index 000000000..72b0302d7 --- /dev/null +++ b/docs/source/faqs.rst @@ -0,0 +1,107 @@ +Frequently Asked Questions (FAQs) +================================= + +In this section, we collect issues reported by users and post the corresponding +solutions. + + +OSError: libtorch_hip.so: cannot open shared object file: no such file or directory +------------------------------------------------------------------------------------ + +One user is using the following code to install ``torch`` and ``torchaudio``: + +.. code-block:: bash + + pip install \ + torch==1.10.0+cu111 \ + torchvision==0.11.0+cu111 \ + torchaudio==0.10.0 \ + -f https://download.pytorch.org/whl/torch_stable.html + +and it throws the following error when running ``tdnn/train.py``: + +.. code-block:: + + OSError: libtorch_hip.so: cannot open shared object file: no such file or directory + +The fix is to specify the CUDA version while installing ``torchaudio``. That +is, change ``torchaudio==0.10.0`` to ``torchaudio==0.10.0+cu111``. Therefore, +the correct command is: + +.. code-block:: bash + + pip install \ + torch==1.10.0+cu111 \ + torchvision==0.11.0+cu111 \ + torchaudio==0.10.0+cu111 \ + -f https://download.pytorch.org/whl/torch_stable.html + +AttributeError: module 'distutils' has no attribute 'version' +-------------------------------------------------------------- + +The error log is: + +.. code-block:: + + Traceback (most recent call last): + File "./tdnn/train.py", line 14, in <module> + from asr_datamodule import YesNoAsrDataModule + File "/home/xxx/code/next-gen-kaldi/icefall/egs/yesno/ASR/tdnn/asr_datamodule.py", line 34, in <module> + from icefall.dataset.datamodule import DataModule + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/__init__.py", line 3, in <module> + from . import ( + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/decode.py", line 23, in <module> + from icefall.utils import add_eos, add_sos, get_texts + File "/home/xxx/code/next-gen-kaldi/icefall/icefall/utils.py", line 39, in <module> + from torch.utils.tensorboard import SummaryWriter + File "/home/xxx/tool/miniconda3/envs/yyy/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py", line 4, in <module> + LooseVersion = distutils.version.LooseVersion + AttributeError: module 'distutils' has no attribute 'version' + +The fix is: + +.. code-block:: bash + + pip uninstall setuptools + + pip install setuptools==58.0.4 + +ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory +----------------------------------------------------------------------------------------------- + +If you are using ``conda`` and encounter the following issue: + +..
code-block:: + + Traceback (most recent call last): + File "/k2-dev/yangyifan/anaconda3/envs/icefall/lib/python3.10/site-packages/k2-1.23.3.dev20230112+cuda11.6.torch1.13.1-py3.10-linux-x86_64.egg/k2/__init__.py", line 24, in + from _k2 import DeterminizeWeightPushingType + ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory + + During handling of the above exception, another exception occurred: + + Traceback (most recent call last): + File "/k2-dev/yangyifan/icefall/egs/librispeech/ASR/./pruned_transducer_stateless7_ctc_bs/decode.py", line 104, in + import k2 + File "/k2-dev/yangyifan/anaconda3/envs/icefall/lib/python3.10/site-packages/k2-1.23.3.dev20230112+cuda11.6.torch1.13.1-py3.10-linux-x86_64.egg/k2/__init__.py", line 30, in + raise ImportError( + ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory + Note: If you're using anaconda and importing k2 on MacOS, + you can probably fix this by setting the environment variable: + export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.10/site-packages:$DYLD_LIBRARY_PATH + +Please first try to find where ``libpython3.10.so.1.0`` locates. + +For instance, + +.. code-block:: bash + + cd $CONDA_PREFIX/lib + find . -name "libpython*" + +If you are able to find it inside ``$CODNA_PREFIX/lib``, please set the +following environment variable: + +.. code-block:: bash + + export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH diff --git a/docs/source/index.rst b/docs/source/index.rst index be9977ca9..8d76eb68b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,7 +21,16 @@ speech recognition recipes using `k2 `_. :caption: Contents: installation/index + faqs model-export/index + +.. toctree:: + :maxdepth: 3 + recipes/index + +.. toctree:: + :maxdepth: 2 + contributing/index huggingface/index diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg index 534b2e534..3019ff03d 100644 --- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg +++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg @@ -1 +1 @@ -k2: >= v1.9k2>= v1.9 \ No newline at end of file +k2: >= v1.9k2>= v1.9 diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg index 4254dc58a..df677ad09 100644 --- a/docs/source/installation/images/python-gt-v3.6-blue.svg +++ b/docs/source/installation/images/python-gt-v3.6-blue.svg @@ -1 +1 @@ -python: >= 3.6python>= 3.6 \ No newline at end of file +python: >= 3.6python>= 3.6 diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg index d3ece9a17..d7007d742 100644 --- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg +++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg @@ -1 +1 @@ -torch: >= 1.6.0torch>= 1.6.0 \ No newline at end of file +torch: >= 1.6.0torch>= 1.6.0 diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index c4474c3d9..5b9fb2664 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -393,6 +393,17 @@ Now let us run the training part: We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU even if there are GPUs available. +.. hint:: + + In case you get a ``Segmentation fault (core dump)`` error, please use: + + .. 
code-block:: bash + + export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + + See more at `` if you are + interested. + The training log is given below: .. code-block:: diff --git a/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..ecbdd4b31 --- /dev/null +++ b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt @@ -0,0 +1,21 @@ +2023-01-11 12:15:38,677 INFO [export-for-ncnn.py:220] device: cpu +2023-01-11 12:15:38,681 INFO [export-for-ncnn.py:229] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_v +alid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampl +ing_factor': 4, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.23.2', 'k2-build-type': +'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'a34171ed85605b0926eebbd0463d059431f4f74a', 'k2-git-date': 'Wed Dec 14 00:06:38 2022', + 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-vers +ion': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'fix-stateless3-train-2022-12-27', 'icefall-git-sha1': '530e8a1-dirty', ' +icefall-git-date': 'Tue Dec 27 13:59:18 2022', 'icefall-path': '/star-fj/fangjun/open-source/icefall', 'k2-path': '/star-fj/fangjun/op +en-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279 +-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '127.0.0.1'}, 'epoch': 30, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefa +ll-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp'), 'bpe_model': './icefall-asr-librispeech-conv-emformer-transdu +cer-stateless2-2022-07-05//data/lang_bpe_500/bpe.model', 'jit': False, 'context_size': 2, 'use_averaged_model': False, 'encoder_dim': +512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'cnn_module_kernel': 31, 'left_context_length': 32, 'chunk_length' +: 32, 'right_context_length': 8, 'memory_size': 32, 'blank_id': 0, 'vocab_size': 500} +2023-01-11 12:15:38,681 INFO [export-for-ncnn.py:231] About to create model +2023-01-11 12:15:40,053 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-conv-emformer-transducer-stateless2-2 +022-07-05/exp/epoch-30.pt +2023-01-11 12:15:40,708 INFO [export-for-ncnn.py:315] Number of model parameters: 75490012 +2023-01-11 12:15:41,681 INFO [export-for-ncnn.py:318] Using torch.jit.trace() +2023-01-11 12:15:41,681 INFO [export-for-ncnn.py:320] Exporting encoder +2023-01-11 12:15:41,682 INFO [export-for-ncnn.py:149] chunk_length: 32, right_context_length: 8 diff --git a/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..fe4460985 --- /dev/null +++ b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt @@ -0,0 +1,18 @@ +2023-02-17 11:22:42,862 INFO [export-for-ncnn.py:222] device: cpu +2023-02-17 11:22:42,865 INFO [export-for-ncnn.py:231] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 
'dim_feedforward': 2048, 'decoder_dim': 512, 'joiner_dim': 512, 'is_pnnx': False, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-dirty', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp'), 'bpe_model': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': 12, 'encoder_dim': 512, 'rnn_hidden_size': 1024, 'aux_layer_period': 0, 'blank_id': 0, 'vocab_size': 500} +2023-02-17 11:22:42,865 INFO [export-for-ncnn.py:235] About to create model +2023-02-17 11:22:43,239 INFO [train.py:472] Disable giga +2023-02-17 11:22:43,249 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/epoch-99.pt +2023-02-17 11:22:44,595 INFO [export-for-ncnn.py:324] encoder parameters: 83137520 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:325] decoder parameters: 257024 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:326] joiner parameters: 781812 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:327] total parameters: 84176356 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:329] Using torch.jit.trace() +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:331] Exporting encoder +2023-02-17 11:22:48,182 INFO [export-for-ncnn.py:158] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.pt +2023-02-17 11:22:48,183 INFO [export-for-ncnn.py:335] Exporting decoder +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/lstm_transducer_stateless2/decoder.py:101: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 
+ need_pad = bool(need_pad) +2023-02-17 11:22:48,259 INFO [export-for-ncnn.py:180] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.pt +2023-02-17 11:22:48,259 INFO [export-for-ncnn.py:339] Exporting joiner +2023-02-17 11:22:48,304 INFO [export-for-ncnn.py:207] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.pt diff --git a/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..25874a414 --- /dev/null +++ b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt @@ -0,0 +1,74 @@ +2023-02-27 20:23:07,473 INFO [export-for-ncnn.py:246] device: cpu +2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:255] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-clean', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp'), 'bpe_model': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': '2,4,3,2,4', 'feedforward_dims': '1024,1024,2048,2048,1024', 'nhead': '8,8,8,8,8', 'encoder_dims': '384,384,384,384,384', 'attention_dims': '192,192,192,192,192', 'encoder_unmasked_dims': '256,256,256,256,256', 'zipformer_downsampling_factors': '1,2,4,8,2', 'cnn_module_kernels': '31,31,31,31,31', 'decoder_dim': 512, 'joiner_dim': 512, 'short_chunk_size': 50, 'num_left_chunks': 4, 'decode_chunk_len': 32, 'blank_id': 0, 'vocab_size': 500} +2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:257] About to create model +2023-02-27 20:23:08,023 INFO [zipformer2.py:419] At encoder stack 4, which has downsampling_factor=2, we will combine the outputs of layers 1 and 3, with downsampling_factors=2 and 8. 
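The export logs above and below are dominated by repeated `TracerWarning` messages from `torch.jit.trace()`. They come from data-dependent `assert` statements in the model code: with the torch version shown in these logs, a shape query such as `x.size(0)` is itself traced, so comparing it with a Python value and handing the result to `assert` converts a tensor to a Python boolean, which is exactly what the warning reports. A minimal stand-alone sketch (assuming only that `torch` is installed) that reproduces the same class of warning:

```bash
python3 - <<'EOF'
import torch


class Tiny(torch.nn.Module):
    def forward(self, x):
        # Under torch.jit.trace(), x.size(0) is a traced value, so this
        # comparison produces a tensor; `assert` turns it into a Python bool,
        # which triggers "Converting a tensor to a Python boolean ...".
        assert x.size(0) == 2, x.size(0)
        return x * 2


# Tracing still succeeds; the warning only notes that the assert's outcome is
# baked into the trace as a constant for this example input shape.
traced = torch.jit.trace(Tiny(), torch.zeros(2, 3))
print(traced(torch.zeros(2, 3)).shape)
EOF
```

The LSTM export log above, for example, continues past these warnings and saves the `*_jit_trace-pnnx.pt` files, so they are informational rather than errors.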
+2023-02-27 20:23:08,037 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/epoch-99.pt +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:346] encoder parameters: 68944004 +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:347] decoder parameters: 260096 +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:348] joiner parameters: 716276 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:349] total parameters: 69920376 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:351] Using torch.jit.trace() +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:353] Exporting encoder +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:174] decode_chunk_len: 32 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:175] T: 39 +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1344: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_len.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1348: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_avg.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1352: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1356: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1360: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val2.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1364: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 
+ assert cached_conv1.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1368: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_conv2.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1373: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert self.left_context_len == cached_key.shape[1], ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1884: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert self.x_size == x.size(0), (self.x_size, x.size(0)) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2442: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == self.left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2449: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == cached_val.shape[0], ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2469: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2473: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.shape[0] == left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2483: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 
+ assert kv_len == k.shape[0], (kv_len, k.shape) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2570: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2926: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2652: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert x.shape[0] == self.x_size, (x.shape[0], self.x_size) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2653: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2666: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.shape[0] == self.left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1543: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1637: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[0] == self.in_x_size, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1643: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. 
This means that the trace might not generalize to other inputs! + assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1571: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + if src.shape[0] != self.in_x_size: +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1763: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1779: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1780: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2) +/star-fj/fangjun/py38/lib/python3.8/site-packages/torch/jit/_trace.py:958: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior. + module._c._create_method_from_trace( +2023-02-27 20:23:19,640 INFO [export-for-ncnn.py:182] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt +2023-02-27 20:23:19,646 INFO [export-for-ncnn.py:357] Exporting decoder +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py:102: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 
+ assert embedding_out.size(-1) == self.context_size +2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:204] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt +2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:361] Exporting joiner +2023-02-27 20:23:19,735 INFO [export-for-ncnn.py:231] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt new file mode 100644 index 000000000..347e7e51a --- /dev/null +++ b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt @@ -0,0 +1,104 @@ +Don't Use GPU. has_gpu: 0, config.use_vulkan_compute: 1 +num encoder conv layers: 88 +num joiner conv layers: 3 +num files: 3 +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +----------encoder---------- +conv_87 : max = 15.942385 threshold = 15.938493 scale = 7.968131 +conv_88 : max = 35.442448 threshold = 15.549335 scale = 8.167552 +conv_89 : max = 23.228289 threshold = 8.001738 scale = 15.871552 +linear_90 : max = 3.976146 threshold = 1.101789 scale = 115.267128 +linear_91 : max = 6.962030 threshold = 5.162033 scale = 24.602713 +linear_92 : max = 12.323041 threshold = 3.853959 scale = 32.953129 +linear_94 : max = 6.905416 threshold = 4.648006 scale = 27.323545 +linear_93 : max = 6.905416 threshold = 5.474093 scale = 23.200188 +linear_95 : max = 1.888012 threshold = 1.403563 scale = 90.483986 +linear_96 : max = 6.856741 threshold = 5.398679 scale = 23.524273 +linear_97 : max = 9.635942 threshold = 2.613655 scale = 48.590950 +linear_98 : max = 6.460340 threshold = 5.670146 scale = 22.398010 +linear_99 : max = 9.532276 threshold = 2.585537 scale = 49.119396 +linear_101 : max = 6.585871 threshold = 5.719224 scale = 22.205809 +linear_100 : max = 6.585871 threshold = 5.751382 scale = 22.081648 +linear_102 : max = 1.593344 threshold = 1.450581 scale = 87.551147 +linear_103 : max = 6.592681 threshold = 5.705824 scale = 22.257959 +linear_104 : max = 8.752957 threshold = 1.980955 scale = 64.110489 +linear_105 : max = 6.696240 threshold = 5.877193 scale = 21.608953 +linear_106 : max = 9.059659 threshold = 2.643138 scale = 48.048950 +linear_108 : max = 6.975461 threshold = 4.589567 scale = 27.671457 +linear_107 : max = 6.975461 threshold = 6.190381 scale = 20.515701 +linear_109 : max = 3.710759 threshold = 2.305635 scale = 55.082436 +linear_110 : max = 7.531228 threshold = 5.731162 scale = 22.159557 +linear_111 : max = 10.528083 threshold = 2.259322 scale = 56.211544 +linear_112 : max = 8.148807 threshold = 5.500842 scale = 23.087374 +linear_113 : max = 8.592566 threshold = 1.948851 scale = 65.166611 +linear_115 : max = 8.437109 threshold = 5.608947 scale = 22.642395 +linear_114 : max = 8.437109 threshold = 6.193942 scale = 20.503904 +linear_116 : max = 3.966980 threshold = 3.200896 scale = 39.676392 +linear_117 : max = 9.451303 threshold = 6.061664 scale = 20.951344 +linear_118 : max = 12.077262 threshold = 3.965800 scale = 32.023804 +linear_119 : max = 9.671615 threshold = 4.847613 scale = 26.198460 +linear_120 : max = 8.625638 threshold = 3.131427 scale = 40.556595 +linear_122 : max = 
10.274080 threshold = 4.888716 scale = 25.978189 +linear_121 : max = 10.274080 threshold = 5.420480 scale = 23.429659 +linear_123 : max = 4.826197 threshold = 3.599617 scale = 35.281532 +linear_124 : max = 11.396383 threshold = 7.325849 scale = 17.335875 +linear_125 : max = 9.337198 threshold = 3.941410 scale = 32.221970 +linear_126 : max = 9.699965 threshold = 4.842878 scale = 26.224073 +linear_127 : max = 8.775370 threshold = 3.884215 scale = 32.696438 +linear_129 : max = 9.872276 threshold = 4.837319 scale = 26.254213 +linear_128 : max = 9.872276 threshold = 7.180057 scale = 17.687883 +linear_130 : max = 4.150427 threshold = 3.454298 scale = 36.765789 +linear_131 : max = 11.112692 threshold = 7.924847 scale = 16.025545 +linear_132 : max = 11.852893 threshold = 3.116593 scale = 40.749626 +linear_133 : max = 11.517084 threshold = 5.024665 scale = 25.275314 +linear_134 : max = 10.683807 threshold = 3.878618 scale = 32.743618 +linear_136 : max = 12.421055 threshold = 6.322729 scale = 20.086264 +linear_135 : max = 12.421055 threshold = 5.309880 scale = 23.917679 +linear_137 : max = 4.827781 threshold = 3.744595 scale = 33.915554 +linear_138 : max = 14.422395 threshold = 7.742882 scale = 16.402161 +linear_139 : max = 8.527538 threshold = 3.866123 scale = 32.849449 +linear_140 : max = 12.128619 threshold = 4.657793 scale = 27.266134 +linear_141 : max = 9.839593 threshold = 3.845993 scale = 33.021378 +linear_143 : max = 12.442304 threshold = 7.099039 scale = 17.889746 +linear_142 : max = 12.442304 threshold = 5.325038 scale = 23.849592 +linear_144 : max = 5.929444 threshold = 5.618206 scale = 22.605080 +linear_145 : max = 13.382126 threshold = 9.321095 scale = 13.625010 +linear_146 : max = 9.894987 threshold = 3.867645 scale = 32.836517 +linear_147 : max = 10.915313 threshold = 4.906028 scale = 25.886522 +linear_148 : max = 9.614287 threshold = 3.908151 scale = 32.496181 +linear_150 : max = 11.724932 threshold = 4.485588 scale = 28.312899 +linear_149 : max = 11.724932 threshold = 5.161146 scale = 24.606939 +linear_151 : max = 7.164453 threshold = 5.847355 scale = 21.719223 +linear_152 : max = 13.086471 threshold = 5.984121 scale = 21.222834 +linear_153 : max = 11.099524 threshold = 3.991601 scale = 31.816805 +linear_154 : max = 10.054585 threshold = 4.489706 scale = 28.286930 +linear_155 : max = 12.389185 threshold = 3.100321 scale = 40.963501 +linear_157 : max = 9.982999 threshold = 5.154796 scale = 24.637253 +linear_156 : max = 9.982999 threshold = 8.537706 scale = 14.875190 +linear_158 : max = 8.420287 threshold = 6.502287 scale = 19.531588 +linear_159 : max = 25.014746 threshold = 9.423280 scale = 13.477261 +linear_160 : max = 45.633553 threshold = 5.715335 scale = 22.220921 +linear_161 : max = 20.371849 threshold = 5.117830 scale = 24.815203 +linear_162 : max = 12.492933 threshold = 3.126283 scale = 40.623318 +linear_164 : max = 20.697504 threshold = 4.825712 scale = 26.317358 +linear_163 : max = 20.697504 threshold = 5.078367 scale = 25.008038 +linear_165 : max = 9.023975 threshold = 6.836278 scale = 18.577358 +linear_166 : max = 34.860619 threshold = 7.259792 scale = 17.493614 +linear_167 : max = 30.380934 threshold = 5.496160 scale = 23.107042 +linear_168 : max = 20.691216 threshold = 4.733317 scale = 26.831076 +linear_169 : max = 9.723948 threshold = 3.952728 scale = 32.129707 +linear_171 : max = 21.034811 threshold = 5.366547 scale = 23.665123 +linear_170 : max = 21.034811 threshold = 5.356277 scale = 23.710501 +linear_172 : max = 10.556884 threshold = 5.729481 scale = 22.166058 
+linear_173 : max = 20.033039 threshold = 10.207264 scale = 12.442120 +linear_174 : max = 11.597379 threshold = 2.658676 scale = 47.768131 +----------joiner---------- +linear_2 : max = 19.293503 threshold = 14.305265 scale = 8.877850 +linear_1 : max = 10.812222 threshold = 8.766452 scale = 14.487047 +linear_3 : max = 0.999999 threshold = 0.999755 scale = 127.031174 +ncnn int8 calibration table create success, best wish for your int8 inference has a low accuracy loss...\(^0^)/...233... diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt new file mode 100644 index 000000000..d39215b14 --- /dev/null +++ b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt @@ -0,0 +1,44 @@ +Don't Use GPU. has_gpu: 0, config.use_vulkan_compute: 1 +num encoder conv layers: 28 +num joiner conv layers: 3 +num files: 3 +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +----------encoder---------- +conv_15 : max = 15.942385 threshold = 15.930708 scale = 7.972025 +conv_16 : max = 44.978855 threshold = 17.031788 scale = 7.456645 +conv_17 : max = 17.868437 threshold = 7.830528 scale = 16.218575 +linear_18 : max = 3.107259 threshold = 1.194808 scale = 106.293236 +linear_19 : max = 6.193777 threshold = 4.634748 scale = 27.401705 +linear_20 : max = 9.259933 threshold = 2.606617 scale = 48.722160 +linear_21 : max = 5.186600 threshold = 4.790260 scale = 26.512129 +linear_22 : max = 9.759041 threshold = 2.265832 scale = 56.050053 +linear_23 : max = 3.931209 threshold = 3.099090 scale = 40.979767 +linear_24 : max = 10.324160 threshold = 2.215561 scale = 57.321835 +linear_25 : max = 3.800708 threshold = 3.599352 scale = 35.284134 +linear_26 : max = 10.492444 threshold = 3.153369 scale = 40.274391 +linear_27 : max = 3.660161 threshold = 2.720994 scale = 46.674126 +linear_28 : max = 9.415265 threshold = 3.174434 scale = 40.007133 +linear_29 : max = 4.038418 threshold = 3.118534 scale = 40.724262 +linear_30 : max = 10.072084 threshold = 3.936867 scale = 32.259155 +linear_31 : max = 4.342712 threshold = 3.599489 scale = 35.282787 +linear_32 : max = 11.340535 threshold = 3.120308 scale = 40.701103 +linear_33 : max = 3.846987 threshold = 3.630030 scale = 34.985939 +linear_34 : max = 10.686298 threshold = 2.204571 scale = 57.607586 +linear_35 : max = 4.904821 threshold = 4.575518 scale = 27.756420 +linear_36 : max = 11.806659 threshold = 2.585589 scale = 49.118401 +linear_37 : max = 6.402340 threshold = 5.047157 scale = 25.162680 +linear_38 : max = 11.174589 threshold = 1.923361 scale = 66.030258 +linear_39 : max = 16.178576 threshold = 7.556058 scale = 16.807705 +linear_40 : max = 12.901954 threshold = 5.301267 scale = 23.956539 +linear_41 : max = 14.839805 threshold = 7.597429 scale = 16.716181 +linear_42 : max = 10.178945 threshold = 2.651595 scale = 47.895699 +----------joiner---------- +linear_2 : max = 24.829245 threshold = 16.627592 scale = 7.637907 +linear_1 : max = 10.746186 threshold = 5.255032 scale = 24.167313 +linear_3 : max = 1.000000 threshold = 0.999756 scale = 127.031013 +ncnn int8 calibration table create success, best wish for your int8 inference has a low accuracy loss...\(^0^)/...233... 
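In both calibration tables above, the per-layer ``scale`` appears to be derived from the calibrated ``threshold`` as ``127 / threshold``, i.e., the calibrated activation range is mapped onto the signed ``int8`` range. The following is a minimal Python sketch, assuming only the ``name : max = ... threshold = ... scale = ...`` row format shown above (it is not one of the repository's scripts), that checks this relationship against a few rows copied from the LSTM table:

.. code-block:: python

   import re

   # A few rows copied verbatim from the calibration tables above.
   rows = [
       "conv_15 : max = 15.942385 threshold = 15.930708 scale = 7.972025",
       "linear_42 : max = 10.178945 threshold = 2.651595 scale = 47.895699",
       "linear_3 : max = 1.000000 threshold = 0.999756 scale = 127.031013",
   ]

   pattern = re.compile(
       r"(\S+)\s*:\s*max\s*=\s*([\d.]+)\s*"
       r"threshold\s*=\s*([\d.]+)\s*scale\s*=\s*([\d.]+)"
   )

   for row in rows:
       name, _max, threshold, scale = pattern.search(row).groups()
       expected = 127.0 / float(threshold)
       # The reported scale matches 127 / threshold up to rounding.
       assert abs(float(scale) - expected) < 1e-2, (name, scale, expected)
       print(f"{name}: scale={scale}, 127/threshold={expected:.6f}")

For instance, ``linear_3`` in the joiner section has ``threshold = 0.999756``, and ``127 / 0.999756`` is approximately ``127.031``, matching the reported ``scale``.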
diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt new file mode 100644 index 000000000..114fe7342 --- /dev/null +++ b/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt @@ -0,0 +1,7 @@ +2023-01-11 14:02:12,216 INFO [streaming-ncnn-decode.py:320] {'tokens': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav'} +T 51 32 +2023-01-11 14:02:13,141 INFO [streaming-ncnn-decode.py:328] Constructing Fbank computer +2023-01-11 14:02:13,151 INFO [streaming-ncnn-decode.py:331] Reading sound files: ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav +2023-01-11 14:02:13,176 INFO [streaming-ncnn-decode.py:336] torch.Size([106000]) +2023-01-11 14:02:17,581 INFO [streaming-ncnn-decode.py:380] ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav +2023-01-11 14:02:17,581 INFO [streaming-ncnn-decode.py:381] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt new file mode 100644 index 000000000..3606eae3d --- /dev/null +++ b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt @@ -0,0 +1,6 @@ +2023-02-17 11:37:30,861 INFO [streaming-ncnn-decode.py:255] {'tokens': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': 
'./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav'} +2023-02-17 11:37:31,425 INFO [streaming-ncnn-decode.py:263] Constructing Fbank computer +2023-02-17 11:37:31,427 INFO [streaming-ncnn-decode.py:266] Reading sound files: ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav +2023-02-17 11:37:31,431 INFO [streaming-ncnn-decode.py:271] torch.Size([106000]) +2023-02-17 11:37:34,115 INFO [streaming-ncnn-decode.py:342] ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav +2023-02-17 11:37:34,115 INFO [streaming-ncnn-decode.py:343] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt new file mode 100644 index 000000000..5b4969e0f --- /dev/null +++ b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt @@ -0,0 +1,7 @@ +2023-02-27 20:43:40,283 INFO [streaming-ncnn-decode.py:349] {'tokens': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav'} +2023-02-27 20:43:41,260 INFO [streaming-ncnn-decode.py:357] Constructing Fbank computer +2023-02-27 20:43:41,264 INFO [streaming-ncnn-decode.py:360] Reading sound files: ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav +2023-02-27 20:43:41,269 INFO [streaming-ncnn-decode.py:365] torch.Size([106000]) +2023-02-27 20:43:41,280 INFO [streaming-ncnn-decode.py:372] number of states: 35 +2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:410] ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav +2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:411] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst new file mode 100644 index 000000000..12b370143 --- /dev/null +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -0,0 +1,753 @@ +.. 
_export_conv_emformer_transducer_models_to_ncnn: + +Export ConvEmformer transducer models to ncnn +============================================= + +We use the pre-trained model from the following repository as an example: + + - ``_ + +We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. + +.. hint:: + + We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing. + +.. caution:: + + Please use a more recent version of PyTorch. For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +--------------------------------- + +.. hint:: + + You can also refer to ``_ to download the pre-trained model. + + You have to install `git-lfs`_ before you continue. + +.. code-block:: bash + + cd egs/librispeech/ASR + + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + + git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + + +In the above code, we downloaded the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05``. + +.. _export_for_ncnn_install_ncnn_and_pnnx: + +2. Install ncnn and pnnx +------------------------ + +.. code-block:: bash + + # We put ncnn into $HOME/open-source/ncnn + # You can change it to anywhere you like + + cd $HOME + mkdir -p open-source + cd open-source + + git clone https://github.com/csukuangfj/ncnn + cd ncnn + git submodule update --recursive --init + + # Note: We don't use "python setup.py install" or "pip install ." here + + mkdir -p build-wheel + cd build-wheel + + cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DNCNN_PYTHON=ON \ + -DNCNN_BUILD_BENCHMARK=OFF \ + -DNCNN_BUILD_EXAMPLES=OFF \ + -DNCNN_BUILD_TOOLS=ON \ + .. + + make -j4 + + cd .. + + # Note: $PWD here is $HOME/open-source/ncnn + + export PYTHONPATH=$PWD/python:$PYTHONPATH + export PATH=$PWD/tools/pnnx/build/src:$PATH + export PATH=$PWD/build-wheel/tools/quantize:$PATH + + # Now build pnnx + cd tools/pnnx + mkdir build + cd build + cmake .. + make -j4 + + ./src/pnnx + +Congratulations! You have successfully installed the following components: + + - ``pnnx``, which is an executable located in + ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use + it to convert models exported by ``torch.jit.trace()``. + - ``ncnn2int8``, which is an executable located in + ``$HOME/open-source/ncnn/build-wheel/tools/quantize``. We will use + it to quantize our models to ``int8``. + - ``ncnn.cpython-38-x86_64-linux-gnu.so``, which is a Python module located + in ``$HOME/open-source/ncnn/python/ncnn``. + + .. note:: + + I am using ``Python 3.8``, so it + is ``ncnn.cpython-38-x86_64-linux-gnu.so``. If you use a different + version, say, ``Python 3.9``, the name would be + ``ncnn.cpython-39-x86_64-linux-gnu.so``. + + Also, if you are not using Linux, the file name would also be different. + But that does not matter. As long as you can compile it, it should work. + +We have set up ``PYTHONPATH`` so that you can use ``import ncnn`` in your +Python code. We have also set up ``PATH`` so that you can use +``pnnx`` and ``ncnn2int8`` later in your terminal. + +.. caution:: + + Please don't use ``_. + We have made some modifications to the offical `ncnn`_. 
+ + We will synchronize ``_ periodically + with the official one. + +3. Export the model via torch.jit.trace() +----------------------------------------- + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp + + ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/ + + ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $dir/exp \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 \ + --encoder-dim 512 + +.. caution:: + + If your model has different configuration parameters, please change them accordingly. + +.. hint:: + + We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model by yourself and if you have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt + + The log shows the model has ``75490012`` parameters, i.e., ``~75 M``. + + .. code-block:: + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + You can see that the file size of the pre-trained model is ``289 MB``, which + is roughly equal to ``75490012*4/1024/1024 = 287.97 MB``. + +After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``, +we will get the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*pnnx* + + -rw-r--r-- 1 kuangfangjun root 1010K Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.pt + + +.. _conv-emformer-step-4-export-torchscript-model-via-pnnx: + +4. Export torchscript model via pnnx +------------------------------------ + +.. hint:: + + Make sure you have set up the ``PATH`` environment variable. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. 
code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 142M Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.5M Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 142 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | ++----------------------------------+------------+ + +You can see that the file sizes of the models after conversion are about one half +of the models before conversion: + + - encoder: 283 MB vs 142 MB + - decoder: 1010 KB vs 503 KB + - joiner: 3.0 MB vs 1.5 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes +for ``float16``. Thus, it is ``twice smaller`` after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +5. Test the exported models in icefall +-------------------------------------- + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The generated 6 files are what we need. You can use the following code to +test the converted models: + +.. 
code-block:: bash + + ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + + +.. _conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn: + +6. Modify the exported encoder for sherpa-ncnn +---------------------------------------------- + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. + +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + +.. code-block:: + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``1060 1342``, the first number ``1060`` specifies the number of layers + in this file, while ``1342`` specifies the number of intermediate outputs + of this file + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output; ``in0`` is the output name of + this layer. + +We need to add 1 extra line and also increment the number of layers. +The result looks like below: + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same + 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. + We don't need to change ``1342`` since the newly added layer has no inputs or outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` + - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` + - ``1=12``, 1 is the key and 12 is the value of the + parameter ``--num-encoder-layers`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. 
+ - ``2=32``, 2 is the key and 32 is the value of the + parameter ``--memory-size`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``3=31``, 3 is the key and 31 is the value of the + parameter ``--cnn-module-kernel`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``4=8``, 4 is the key and 8 is the value of the + parameter ``--left-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``5=32``, 5 is the key and 32 is the value of the + parameter ``--chunk-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``6=8``, 6 is the key and 8 is the value of the + parameter ``--right-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``7=512``, 7 is the key and 512 is the value of the + parameter ``--encoder-dim`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +------+-----------------------------+ + | key | value | + +======+=============================+ + | 0 | 1 (fixed) | + +------+-----------------------------+ + | 1 | ``--num-encoder-layers`` | + +------+-----------------------------+ + | 2 | ``--memory-size`` | + +------+-----------------------------+ + | 3 | ``--cnn-module-kernel`` | + +------+-----------------------------+ + | 4 | ``--left-context-length`` | + +------+-----------------------------+ + | 5 | ``--chunk-length`` | + +------+-----------------------------+ + | 6 | ``--right-context-length`` | + +------+-----------------------------+ + | 7 | ``--encoder-dim`` | + +------+-----------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``1060`` to ``1061``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file. + +Now you can use this model in `sherpa-ncnn`_. +Please refer to the following documentation: + + - Linux/macOS/Windows/arm/aarch64: ``_ + - ``Android``: ``_ + - ``iOS``: ``_ + - Python: ``_ + +We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: + + - ``_ + + You can find more usages there. + +7. (Optional) int8 quantization with sherpa-ncnn +------------------------------------------------ + +This step is optional. + +In this step, we describe how to quantize our model with ``int8``. + +Change :ref:`conv-emformer-step-4-export-torchscript-model-via-pnnx` to +disable ``fp16`` when using ``pnnx``: + +.. code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt fp16=0 + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt fp16=0 + +.. note:: + + We add ``fp16=0`` when exporting the encoder and joiner. 
`ncnn`_ does not + support quantizing the decoder model yet. We will update this documentation + once `ncnn`_ supports it. (Maybe in this year, 2023). + +It will generate the following files + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*_jit_trace-pnnx.ncnn.{param,bin} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +Let us compare again the file sizes: + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ + +You can see that the file sizes are doubled when we disable ``fp16``. + +.. note:: + + You can again use ``streaming-ncnn-decode.py`` to test the exported models. + +Next, follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` +to modify ``encoder_jit_trace-pnnx.ncnn.param``. + +Change + +.. code-block:: bash + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +to + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +.. caution:: + + Please follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` + to change the values for ``SherpaMetaData`` if your model uses a different setting. + + +Next, let us compile `sherpa-ncnn`_ since we will quantize our models within +`sherpa-ncnn`_. + +.. code-block:: bash + + # We will download sherpa-ncnn to $HOME/open-source/ + # You can change it to anywhere you like. + cd $HOME + mkdir -p open-source + + cd open-source + git clone https://github.com/k2-fsa/sherpa-ncnn + cd sherpa-ncnn + mkdir build + cd build + cmake .. 
+   make -j 4
+
+   ./bin/generate-int8-scale-table
+
+   export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH
+
+The output of the above commands is:
+
+.. code-block:: bash
+
+  (py38) kuangfangjun:build$ generate-int8-scale-table
+  Please provide 10 arg. Currently given: 1
+  Usage:
+  generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt
+
+  Each line in wave_filenames.txt is a path to some 16k Hz mono wave file.
+
+We need to create a file ``wave_filenames.txt``, in which we need to put
+some calibration wave files. For testing purposes, we put the ``test_wavs``
+from the pre-trained model repository ``_
+
+.. code-block:: bash
+
+  cd egs/librispeech/ASR
+  cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+  cat <<EOF > wave_filenames.txt
+  ../test_wavs/1089-134686-0001.wav
+  ../test_wavs/1221-135766-0001.wav
+  ../test_wavs/1221-135766-0002.wav
+  EOF
+
+Now we can calculate the scales needed for quantization with the calibration data:
+
+.. code-block:: bash
+
+  cd egs/librispeech/ASR
+  cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+  generate-int8-scale-table \
+    ./encoder_jit_trace-pnnx.ncnn.param \
+    ./encoder_jit_trace-pnnx.ncnn.bin \
+    ./decoder_jit_trace-pnnx.ncnn.param \
+    ./decoder_jit_trace-pnnx.ncnn.bin \
+    ./joiner_jit_trace-pnnx.ncnn.param \
+    ./joiner_jit_trace-pnnx.ncnn.bin \
+    ./encoder-scale-table.txt \
+    ./joiner-scale-table.txt \
+    ./wave_filenames.txt
+
+The output logs are given below:
+
+.. literalinclude:: ./code/generate-int-8-scale-table-for-conv-emformer.txt
+
+It generates the following two files:
+
+.. code-block:: bash
+
+  $ ls -lh encoder-scale-table.txt joiner-scale-table.txt
+  -rw-r--r-- 1 kuangfangjun root 955K Jan 11 17:28 encoder-scale-table.txt
+  -rw-r--r-- 1 kuangfangjun root 18K Jan 11 17:28 joiner-scale-table.txt
+
+.. caution::
+
+   In practice, you definitely need more calibration data to compute the scale table.
+
+Finally, let us use the scale table to quantize our models into ``int8``.
+
+.. code-block:: bash
+
+  ncnn2int8
+
+  usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table]
+
+First, we quantize the encoder model:
+
+.. code-block:: bash
+
+  cd egs/librispeech/ASR
+  cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+  ncnn2int8 \
+    ./encoder_jit_trace-pnnx.ncnn.param \
+    ./encoder_jit_trace-pnnx.ncnn.bin \
+    ./encoder_jit_trace-pnnx.ncnn.int8.param \
+    ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+    ./encoder-scale-table.txt
+
+Next, we quantize the joiner model:
+
+.. code-block:: bash
+
+  ncnn2int8 \
+    ./joiner_jit_trace-pnnx.ncnn.param \
+    ./joiner_jit_trace-pnnx.ncnn.bin \
+    ./joiner_jit_trace-pnnx.ncnn.int8.param \
+    ./joiner_jit_trace-pnnx.ncnn.int8.bin \
+    ./joiner-scale-table.txt
+
+The above two commands generate the following 4 files:
+
+.. code-block:: bash
+
+  -rw-r--r-- 1 kuangfangjun root 99M Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.bin
+  -rw-r--r-- 1 kuangfangjun root 78K Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.param
+  -rw-r--r-- 1 kuangfangjun root 774K Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.bin
+  -rw-r--r-- 1 kuangfangjun root 496 Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.param
+
+Congratulations! You have successfully quantized your model from ``float32`` to ``int8``.
+
+.. caution::
+
+   ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs.
+ + You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param`` + and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like. + + For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can + replace the following invocation: + + .. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + with + + .. code-block:: + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + sherpa-ncnn \ + ../data/lang_bpe_500/tokens.txt \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ../test_wavs/1089-134686-0001.wav + + +The following table compares again the file sizes: + + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.int8.bin | 99 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB | ++----------------------------------------+------------+ + +You can see that the file sizes of the model after ``int8`` quantization +are much smaller. + +.. hint:: + + Currently, only linear layers and convolutional layers are quantized + with ``int8``, so you don't see an exact ``4x`` reduction in file sizes. + +.. note:: + + You need to test the recognition accuracy after ``int8`` quantization. + +You can find the speed comparison at ``_. + + +That's it! Have fun with `sherpa-ncnn`_! diff --git a/docs/source/model-export/export-ncnn-lstm.rst b/docs/source/model-export/export-ncnn-lstm.rst new file mode 100644 index 000000000..8e6dc7466 --- /dev/null +++ b/docs/source/model-export/export-ncnn-lstm.rst @@ -0,0 +1,644 @@ +.. _export_lstm_transducer_models_to_ncnn: + +Export LSTM transducer models to ncnn +------------------------------------- + +We use the pre-trained model from the following repository as an example: + +``_ + +We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. + +.. hint:: + + We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing. + +.. caution:: + + Please use a more recent version of PyTorch. 
For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + You have to install `git-lfs`_ before you continue. + + +.. code-block:: bash + + cd egs/librispeech/ASR + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + + git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + +In the above code, we downloaded the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03``. + +2. Install ncnn and pnnx +^^^^^^^^^^^^^^^^^^^^^^^^ + +Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx` . + + +3. Export the model via torch.jit.trace() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp + + ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + + ./lstm_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $dir/exp \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --num-encoder-layers 12 \ + --encoder-dim 512 \ + --rnn-hidden-size 1024 + +.. hint:: + + We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model by yourself and if you have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-lstm-transducer-for-ncnn-output.txt + + The log shows the model has ``84176356`` parameters, i.e., ``~84 M``. + + .. code-block:: + + ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt + + -rw-r--r-- 1 kuangfangjun root 324M Feb 17 10:34 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt + + You can see that the file size of the pre-trained model is ``324 MB``, which + is roughly equal to ``84176356*4/1024/1024 = 321.107 MB``. + +After running ``lstm_transducer_stateless2/export-for-ncnn.py``, +we will get the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*pnnx.pt + + -rw-r--r-- 1 kuangfangjun root 1010K Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 318M Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 3.0M Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.pt + + +.. _lstm-transducer-step-4-export-torchscript-model-via-pnnx: + +4. Export torchscript model via pnnx +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
hint:: + + Make sure you have set up the ``PATH`` environment variable + in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. code-block:: + + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 503K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 159M Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 21K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.5M Feb 17 11:33 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Feb 17 11:33 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param + + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 318 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 159 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | ++----------------------------------+------------+ + +You can see that the file sizes of the models after conversion are about one half +of the models before conversion: + + - encoder: 318 MB vs 159 MB + - decoder: 1010 KB vs 503 KB + - joiner: 3.0 MB vs 1.5 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes +for ``float16``. Thus, it is ``twice smaller`` after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +5. Test the exported models in icefall +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The generated 6 files are what we need. You can use the following code to +test the converted models: + +.. 
code-block:: bash + + python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-streaming-ncnn-decode-lstm-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + +.. _lstm-modify-the-exported-encoder-for-sherpa-ncnn: + +6. Modify the exported encoder for sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. + +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + +.. code-block:: + + 7767517 + 267 379 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``267 379``, the first number ``267`` specifies the number of layers + in this file, while ``379`` specifies the number of intermediate outputs + of this file + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output; ``in0`` is the output name of + this layer. + +We need to add 1 extra line and also increment the number of layers. +The result looks like below: + +.. code-block:: bash + + 7767517 + 268 379 + SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same + 2. ``268 379``, we have added an extra layer, so we need to update ``267`` to ``268``. + We don't need to change ``379`` since the newly added layer has no inputs or outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` + - ``0=3``, 0 is the key and 3 is the value. MUST be ``0=3`` + - ``1=12``, 1 is the key and 12 is the value of the + parameter ``--num-encoder-layers`` that you provided when running + ``./lstm_transducer_stateless2/export-for-ncnn.py``. 
+ - ``2=512``, 2 is the key and 512 is the value of the + parameter ``--encoder-dim`` that you provided when running + ``./lstm_transducer_stateless2/export-for-ncnn.py``. + - ``3=1024``, 3 is the key and 1024 is the value of the + parameter ``--rnn-hidden-size`` that you provided when running + ``./lstm_transducer_stateless2/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +------+-----------------------------+ + | key | value | + +======+=============================+ + | 0 | 3 (fixed) | + +------+-----------------------------+ + | 1 | ``--num-encoder-layers`` | + +------+-----------------------------+ + | 2 | ``--encoder-dim`` | + +------+-----------------------------+ + | 3 | ``--rnn-hidden-size`` | + +------+-----------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``267`` to ``268``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file. + +Now you can use this model in `sherpa-ncnn`_. +Please refer to the following documentation: + + - Linux/macOS/Windows/arm/aarch64: ``_ + - ``Android``: ``_ + - ``iOS``: ``_ + - Python: ``_ + +We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: + + - ``_ + + You can find more usages there. + +7. (Optional) int8 quantization with sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This step is optional. + +In this step, we describe how to quantize our model with ``int8``. + +Change :ref:`lstm-transducer-step-4-export-torchscript-model-via-pnnx` to +disable ``fp16`` when using ``pnnx``: + +.. code-block:: + + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt fp16=0 + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt fp16=0 + +.. note:: + + We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not + support quantizing the decoder model yet. We will update this documentation + once `ncnn`_ supports it. (Maybe in this year, 2023). + +.. 
code-block:: bash + + ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*_jit_trace-pnnx.ncnn.{param,bin} + + -rw-r--r-- 1 kuangfangjun root 503K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 317M Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 21K Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 3.0M Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param + + +Let us compare again the file sizes: + ++----------------------------------------+------------+ +| File name | File size | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.pt | 318 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 159 MB | ++----------------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | ++----------------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 317 MB | ++----------------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | ++----------------------------------------+------------+ + +You can see that the file sizes are doubled when we disable ``fp16``. + +.. note:: + + You can again use ``streaming-ncnn-decode.py`` to test the exported models. + +Next, follow :ref:`lstm-modify-the-exported-encoder-for-sherpa-ncnn` +to modify ``encoder_jit_trace-pnnx.ncnn.param``. + +Change + +.. code-block:: bash + + 7767517 + 267 379 + Input in0 0 1 in0 + +to + +.. code-block:: bash + + 7767517 + 268 379 + SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024 + Input in0 0 1 in0 + +.. caution:: + + Please follow :ref:`lstm-modify-the-exported-encoder-for-sherpa-ncnn` + to change the values for ``SherpaMetaData`` if your model uses a different setting. + +Next, let us compile `sherpa-ncnn`_ since we will quantize our models within +`sherpa-ncnn`_. + +.. code-block:: bash + + # We will download sherpa-ncnn to $HOME/open-source/ + # You can change it to anywhere you like. + cd $HOME + mkdir -p open-source + + cd open-source + git clone https://github.com/k2-fsa/sherpa-ncnn + cd sherpa-ncnn + mkdir build + cd build + cmake .. + make -j 4 + + ./bin/generate-int8-scale-table + + export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH + +The output of the above commands are: + +.. code-block:: bash + + (py38) kuangfangjun:build$ generate-int8-scale-table + Please provide 10 arg. 
Currently given: 1 + Usage: + generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt + + Each line in wave_filenames.txt is a path to some 16k Hz mono wave file. + +We need to create a file ``wave_filenames.txt``, in which we need to put +some calibration wave files. For testing purpose, we put the ``test_wavs`` +from the pre-trained model repository +``_ + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + cat < wave_filenames.txt + ../test_wavs/1089-134686-0001.wav + ../test_wavs/1221-135766-0001.wav + ../test_wavs/1221-135766-0002.wav + EOF + +Now we can calculate the scales needed for quantization with the calibration data: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + generate-int8-scale-table \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./decoder_jit_trace-pnnx.ncnn.param \ + ./decoder_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./encoder-scale-table.txt \ + ./joiner-scale-table.txt \ + ./wave_filenames.txt + +The output logs are in the following: + +.. literalinclude:: ./code/generate-int-8-scale-table-for-lstm.txt + +It generates the following two files: + +.. code-block:: bash + + ls -lh encoder-scale-table.txt joiner-scale-table.txt + + -rw-r--r-- 1 kuangfangjun root 345K Feb 17 12:13 encoder-scale-table.txt + -rw-r--r-- 1 kuangfangjun root 17K Feb 17 12:13 joiner-scale-table.txt + +.. caution:: + + Definitely, you need more calibration data to compute the scale table. + +Finally, let us use the scale table to quantize our models into ``int8``. + +.. code-block:: bash + + ncnn2int8 + + usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table] + +First, we quantize the encoder model: + +.. code-block:: bash + + cd egs/librispeech/ASR + cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/ + + ncnn2int8 \ + ./encoder_jit_trace-pnnx.ncnn.param \ + ./encoder_jit_trace-pnnx.ncnn.bin \ + ./encoder_jit_trace-pnnx.ncnn.int8.param \ + ./encoder_jit_trace-pnnx.ncnn.int8.bin \ + ./encoder-scale-table.txt + +Next, we quantize the joiner model: + +.. code-block:: bash + + ncnn2int8 \ + ./joiner_jit_trace-pnnx.ncnn.param \ + ./joiner_jit_trace-pnnx.ncnn.bin \ + ./joiner_jit_trace-pnnx.ncnn.int8.param \ + ./joiner_jit_trace-pnnx.ncnn.int8.bin \ + ./joiner-scale-table.txt + +The above two commands generate the following 4 files: + +.. code-block:: + + -rw-r--r-- 1 kuangfangjun root 218M Feb 17 12:19 encoder_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 21K Feb 17 12:19 encoder_jit_trace-pnnx.ncnn.int8.param + -rw-r--r-- 1 kuangfangjun root 774K Feb 17 12:19 joiner_jit_trace-pnnx.ncnn.int8.bin + -rw-r--r-- 1 kuangfangjun root 496 Feb 17 12:19 joiner_jit_trace-pnnx.ncnn.int8.param + +Congratulations! You have successfully quantized your model from ``float32`` to ``int8``. + +.. caution:: + + ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs. + + You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param`` + and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like. + + For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can + replace the following invocation: + + .. 
code-block:: bash

+
+      cd egs/librispeech/ASR
+      cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+      sherpa-ncnn \
+        ../data/lang_bpe_500/tokens.txt \
+        ./encoder_jit_trace-pnnx.ncnn.param \
+        ./encoder_jit_trace-pnnx.ncnn.bin \
+        ./decoder_jit_trace-pnnx.ncnn.param \
+        ./decoder_jit_trace-pnnx.ncnn.bin \
+        ./joiner_jit_trace-pnnx.ncnn.param \
+        ./joiner_jit_trace-pnnx.ncnn.bin \
+        ../test_wavs/1089-134686-0001.wav
+
+   with
+
+   .. code-block:: bash
+
+      cd egs/librispeech/ASR
+      cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+      sherpa-ncnn \
+        ../data/lang_bpe_500/tokens.txt \
+        ./encoder_jit_trace-pnnx.ncnn.int8.param \
+        ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+        ./decoder_jit_trace-pnnx.ncnn.param \
+        ./decoder_jit_trace-pnnx.ncnn.bin \
+        ./joiner_jit_trace-pnnx.ncnn.param \
+        ./joiner_jit_trace-pnnx.ncnn.bin \
+        ../test_wavs/1089-134686-0001.wav
+
+The following table compares the file sizes again:
+
++----------------------------------------+------------+
+| File name                              | File size  |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt              | 318 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt              | 1010 KB    |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt               | 3.0 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 159 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16)  | 1.5 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 317 MB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32)  | 3.0 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.int8.bin   | 218 MB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.int8.bin    | 774 KB     |
++----------------------------------------+------------+
+
+You can see that the file size of the joiner model after ``int8`` quantization
+is much smaller. However, the size of the encoder model is even larger than
+the ``fp16`` counterpart. The reason is that `ncnn`_ currently does not support
+quantizing ``LSTM`` layers into ``8-bit``. Please see
+``_
+
+.. hint::
+
+   Currently, only linear layers and convolutional layers are quantized
+   with ``int8``, so you don't see an exact ``4x`` reduction in file sizes.
+
+.. note::
+
+   You need to test the recognition accuracy after ``int8`` quantization.
+
+
+That's it! Have fun with `sherpa-ncnn`_!
diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst
new file mode 100644
index 000000000..5c81d25ca
--- /dev/null
+++ b/docs/source/model-export/export-ncnn-zipformer.rst
@@ -0,0 +1,383 @@
+.. _export_streaming_zipformer_transducer_models_to_ncnn:
+
+Export streaming Zipformer transducer models to ncnn
+----------------------------------------------------
+
+We use the pre-trained model from the following repository as an example:
+
+``_
+
+We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_.
+
+.. hint::
+
+   We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing.
+
+.. caution::
+
+   Please use a more recent version of PyTorch. 
For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + You have to install `git-lfs`_ before you continue. + + +.. code-block:: bash + + cd egs/librispeech/ASR + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + + git lfs pull --include "exp/pretrained.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + +In the above code, we downloaded the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29``. + +2. Install ncnn and pnnx +^^^^^^^^^^^^^^^^^^^^^^^^ + +Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx` . + + +3. Export the model via torch.jit.trace() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp + + ln -s pretrained.pt epoch-99.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + + ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --exp-dir $dir/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + \ + --decode-chunk-len 32 \ + --num-left-chunks 4 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +.. caution:: + + If your model has different configuration parameters, please change them accordingly. + +.. hint:: + + We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model by yourself and if you have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-zipformer-transducer-for-ncnn-output.txt + + The log shows the model has ``69920376`` parameters, i.e., ``~69.9 M``. + + .. code-block:: bash + + ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt + -rw-r--r-- 1 kuangfangjun root 269M Jan 12 12:53 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt + + You can see that the file size of the pre-trained model is ``269 MB``, which + is roughly equal to ``69920376*4/1024/1024 = 266.725 MB``. + +After running ``pruned_transducer_stateless7_streaming/export-for-ncnn.py``, +we will get the following files: + +.. 
code-block:: bash + + ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*pnnx.pt + + -rw-r--r-- 1 kuangfangjun root 1022K Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 266M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 2.8M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt + +.. _zipformer-transducer-step-4-export-torchscript-model-via-pnnx: + +4. Export torchscript model via pnnx +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. hint:: + + Make sure you have set up the ``PATH`` environment variable + in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. code-block:: + + cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 509K Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 133M Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 152K Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.4M Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. 
see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 266 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1022 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 2.8 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 133 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 509 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.4 MB | ++----------------------------------+------------+ + +You can see that the file sizes of the models after conversion are about one half +of the models before conversion: + + - encoder: 266 MB vs 133 MB + - decoder: 1022 KB vs 509 KB + - joiner: 2.8 MB vs 1.4 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes +for ``float16``. Thus, it is ``twice smaller`` after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +5. Test the exported models in icefall +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The generated 6 files are what we need. You can use the following code to +test the converted models: + +.. code-block:: bash + + python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + +.. _zipformer-modify-the-exported-encoder-for-sherpa-ncnn: + +6. Modify the exported encoder for sherpa-ncnn +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. 
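+
+.. hint::
+
+   The required change is explained step by step below. If you prefer to script
+   it, the following is a rough sketch (it is not part of the icefall recipe)
+   that inserts the ``SherpaMetaData`` line and increments the layer count in
+   one go. The metadata values hard-coded here are the ones used in this
+   tutorial; if your model uses a different configuration, adjust them as
+   explained below.
+
+   .. code-block:: python
+
+      #!/usr/bin/env python3
+      # A minimal sketch that automates the manual edit described below.
+      # It assumes the SherpaMetaData values of this tutorial's model.
+
+      param = "encoder_jit_trace-pnnx.ncnn.param"
+      meta = (
+          "SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 "
+          "-23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 "
+          "-23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 "
+          "-23320=5,31,31,31,31,31"
+      )
+
+      with open(param) as f:
+          lines = f.read().splitlines()
+
+      # The second line contains "<num-layers> <num-blobs>". Adding
+      # SherpaMetaData adds one layer; the blob count stays unchanged.
+      num_layers, num_blobs = lines[1].split()
+      lines[1] = f"{int(num_layers) + 1} {num_blobs}"
+      lines.insert(2, meta)
+
+      with open(param, "w") as f:
+          f.write("\n".join(lines) + "\n")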
+ +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + +.. code-block:: + + 7767517 + 2028 2547 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``2028 2547``, the first number ``2028`` specifies the number of layers + in this file, while ``2547`` specifies the number of intermediate outputs + of this file + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output; ``in0`` is the output name of + this layer. + +We need to add 1 extra line and also increment the number of layers. +The result looks like below: + +.. code-block:: bash + + 7767517 + 2029 2547 + SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same + 2. ``2029 2547``, we have added an extra layer, so we need to update ``2028`` to ``2029``. + We don't need to change ``2547`` since the newly added layer has no inputs or outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` + - ``0=2``, 0 is the key and 2 is the value. MUST be ``0=2`` + - ``1=32``, 1 is the key and 32 is the value of the + parameter ``--decode-chunk-len`` that you provided when running + ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``2=4``, 2 is the key and 4 is the value of the + parameter ``--num-left-chunks`` that you provided when running + ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``3=7``, 3 is the key and 7 is the value of for the amount of padding + used in the Conv2DSubsampling layer. It should be 7 for zipformer + if you don't change zipformer.py. + - ``-23316=5,2,4,3,2,4``, attribute 16, this is an array attribute. + It is attribute 16 since -23300 - (-23316) = 16. + The first element of the array is the length of the array, which is 5 in our case. + ``2,4,3,2,4`` is the value of ``--num-encoder-layers``that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``-23317=5,384,384,384,384,384``, attribute 17. + The first element of the array is the length of the array, which is 5 in our case. + ``384,384,384,384,384`` is the value of ``--encoder-dims``that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``-23318=5,192,192,192,192,192``, attribute 18. + The first element of the array is the length of the array, which is 5 in our case. + ``192,192,192,192,192`` is the value of ``--attention-dims`` that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``-23319=5,1,2,4,8,2``, attribute 19. + The first element of the array is the length of the array, which is 5 in our case. 
+ ``1,2,4,8,2`` is the value of ``--zipformer-downsampling-factors`` that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + - ``-23320=5,31,31,31,31,31``, attribute 20. + The first element of the array is the length of the array, which is 5 in our case. + ``31,31,31,31,31`` is the value of ``--cnn-module-kernels`` that you provided + when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +----------+--------------------------------------------+ + | key | value | + +==========+============================================+ + | 0 | 2 (fixed) | + +----------+--------------------------------------------+ + | 1 | ``-decode-chunk-len`` | + +----------+--------------------------------------------+ + | 2 | ``--num-left-chunks`` | + +----------+--------------------------------------------+ + | 3 | 7 (if you don't change code) | + +----------+--------------------------------------------+ + |-23316 | ``--num-encoder-layer`` | + +----------+--------------------------------------------+ + |-23317 | ``--encoder-dims`` | + +----------+--------------------------------------------+ + |-23318 | ``--attention-dims`` | + +----------+--------------------------------------------+ + |-23319 | ``--zipformer-downsampling-factors`` | + +----------+--------------------------------------------+ + |-23320 | ``--cnn-module-kernels`` | + +----------+--------------------------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``2028`` to ``2029``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file. + +Now you can use this model in `sherpa-ncnn`_. +Please refer to the following documentation: + + - Linux/macOS/Windows/arm/aarch64: ``_ + - ``Android``: ``_ + - ``iOS``: ``_ + - Python: ``_ + +We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: + + - ``_ + + You can find more usages there. diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 3dbb8b514..9eb5f85d2 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -1,12 +1,37 @@ Export to ncnn ============== -We support exporting LSTM transducer models to `ncnn `_. +We support exporting the following models +to `ncnn `_: -Please refer to :ref:`export-model-for-ncnn` for details. + - `Zipformer transducer models `_ -We also provide ``_ -performing speech recognition using ``ncnn`` with exported models. -It has been tested on Linux, macOS, Windows, and Raspberry Pi. The project is -self-contained and can be statically linked to produce a binary containing -everything needed. + - `LSTM transducer models `_ + + - `ConvEmformer transducer models `_ + +We also provide `sherpa-ncnn`_ +for performing speech recognition using `ncnn`_ with exported models. 
+It has been tested on the following platforms: + + - Linux + - macOS + - Windows + - ``Android`` + - ``iOS`` + - ``Raspberry Pi`` + - `爱芯派 `_ (`MAIX-III AXera-Pi `_). + - `RV1126 `_ + +`sherpa-ncnn`_ is self-contained and can be statically linked to produce +a binary containing everything needed. Please refer +to its documentation for details: + + - ``_ + + +.. toctree:: + + export-ncnn-zipformer + export-ncnn-conv-emformer + export-ncnn-lstm diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index dd4b3437a..aa77204cb 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -1,69 +1,95 @@ Export to ONNX ============== -In this section, we describe how to export models to ONNX. +In this section, we describe how to export models to `ONNX`_. + +In each recipe, there is a file called ``export-onnx.py``, which is used +to export trained models to `ONNX`_. + +There is also a file named ``onnx_pretrained.py``, which you can use +the exported `ONNX`_ model in Python with `onnxruntime`_ to decode sound files. + +sherpa-onnx +----------- + +We have a separate repository `sherpa-onnx`_ for deploying your exported models +on various platforms such as: + + - iOS + - Android + - Raspberry Pi + - Linux/macOS/Windows + + +Please see the documentation of `sherpa-onnx`_ for details: + + ``_ + +Example +------- + +In the following, we demonstrate how to export a streaming Zipformer pre-trained +model from +``_ +to `ONNX`_. + +Download the pre-trained model +------------------------------ .. hint:: - Only non-streaming conformer transducer models are tested. - - -When to use it --------------- - -It you want to use an inference framework that supports ONNX -to run the pretrained model. - - -How to export -------------- - -We use -``_ -as an example in the following. + We assume you have installed `git-lfs`_. .. code-block:: bash - cd egs/librispeech/ASR - epoch=14 - avg=2 - ./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch $epoch \ - --avg $avg \ - --onnx 1 + cd egs/librispeech/ASR -It will generate the following files inside ``pruned_transducer_stateless3/exp``: + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) - - ``encoder.onnx`` - - ``decoder.onnx`` - - ``joiner.onnx`` - - ``joiner_encoder_proj.onnx`` - - ``joiner_decoder_proj.onnx`` + pushd $repo + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/pretrained.pt" + cd exp + ln -s pretrained.pt epoch-99.pt + popd -You can use ``./pruned_transducer_stateless3/exp/onnx_pretrained.py`` to decode -waves with the generated files: +Export the model to ONNX +------------------------ .. 
code-block:: bash - ./pruned_transducer_stateless3/onnx_pretrained.py \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ - --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ - --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ - /path/to/foo.wav \ - /path/to/bar.wav \ - /path/to/baz.wav + ./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ +.. warning:: -How to use the exported model ------------------------------ + ``export-onnx.py`` from different recipes has different options. -We also provide ``_ -performing speech recognition using `onnxruntime `_ -with exported models. -It has been tested on Linux, macOS, and Windows. + In the above example, ``--decode-chunk-len`` is specific for the + streaming Zipformer. Other models won't have such an option. + +It will generate the following 3 files in ``$repo/exp`` + + - ``encoder-epoch-99-avg-1.onnx`` + - ``decoder-epoch-99-avg-1.onnx`` + - ``joiner-epoch-99-avg-1.onnx`` + +Decode sound files with exported ONNX models +-------------------------------------------- + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav diff --git a/docs/source/model-export/export-with-torch-jit-script.rst b/docs/source/model-export/export-with-torch-jit-script.rst index a041dc1d5..efd7dc2e1 100644 --- a/docs/source/model-export/export-with-torch-jit-script.rst +++ b/docs/source/model-export/export-with-torch-jit-script.rst @@ -1,7 +1,7 @@ .. _export-model-with-torch-jit-script: Export model with torch.jit.script() -=================================== +==================================== In this section, we describe how to export a model via ``torch.jit.script()``. diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst similarity index 99% rename from docs/source/recipes/aishell/conformer_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst index 72690e102..6e30ce397 100644 --- a/docs/source/recipes/aishell/conformer_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst @@ -703,7 +703,7 @@ It will show you the following message: HLG decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. 
code-block:: bash diff --git a/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg similarity index 100% rename from docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg diff --git a/docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg similarity index 100% rename from docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg diff --git a/docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png similarity index 100% rename from docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/Non-streaming-ASR/aishell/index.rst similarity index 99% rename from docs/source/recipes/aishell/index.rst rename to docs/source/recipes/Non-streaming-ASR/aishell/index.rst index d072d6e9c..b77d59bca 100644 --- a/docs/source/recipes/aishell/index.rst +++ b/docs/source/recipes/Non-streaming-ASR/aishell/index.rst @@ -19,4 +19,3 @@ It can be downloaded from ``_ tdnn_lstm_ctc conformer_ctc stateless_transducer - diff --git a/docs/source/recipes/aishell/stateless_transducer.rst b/docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst similarity index 100% rename from docs/source/recipes/aishell/stateless_transducer.rst rename to docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst diff --git a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst similarity index 100% rename from docs/source/recipes/aishell/tdnn_lstm_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst diff --git a/docs/source/recipes/Non-streaming-ASR/index.rst b/docs/source/recipes/Non-streaming-ASR/index.rst new file mode 100644 index 000000000..67123a648 --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/index.rst @@ -0,0 +1,10 @@ +Non Streaming ASR +================= + +.. toctree:: + :maxdepth: 2 + + aishell/index + librispeech/index + timit/index + yesno/index diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst similarity index 99% rename from docs/source/recipes/librispeech/conformer_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst index 4656acfd6..b7f89c89f 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst @@ -888,7 +888,7 @@ It will show you the following message: CTC decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. code-block:: bash @@ -926,7 +926,7 @@ Its output is: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION HLG decoding -^^^^^^^^^^^^ +~~~~~~~~~~~~ .. 
code-block:: bash @@ -966,7 +966,7 @@ The output is: HLG decoding + n-gram LM rescoring -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash @@ -1012,7 +1012,7 @@ The output is: HLG decoding + n-gram LM rescoring + attention decoder rescoring -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: bash diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst new file mode 100644 index 000000000..ea9f350cd --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst @@ -0,0 +1,223 @@ +Distillation with HuBERT +======================== + +This tutorial shows you how to perform knowledge distillation in `icefall`_ +with the `LibriSpeech`_ dataset. The distillation method +used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD). +Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation `_ +for more details about MVQ-KD. + +.. note:: + + This tutorial is based on recipe + `pruned_transducer_stateless4 `_. + Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes + with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you + encounter any problems, please open an issue here `icefall `_. + +.. note:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for `icefall`_. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +Data preparation +---------------- + +We first prepare necessary training data for `LibriSpeech`_. +This is the same as in :ref:`non_streaming_librispeech_pruned_transducer_stateless`. + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to :ref:`codebook_index_preparation` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 # run only stage 0 + $ ./prepare.sh --stage 2 --stop-stage 5 # run from stage 2 to stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech`_ + dataset and the `musan`_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +.. _codebook_index_preparation: + +Codebook index preparation +-------------------------- + +Here, we prepare necessary data for MVQ-KD. 
This requires the generation of codebook indexes (please read our
+`paper `_ if you are interested in details). In this tutorial, we use the
+pre-computed codebook indexes for convenience. The only thing you need to do is to
+run `./distillation_with_hubert.sh `_.
+
+.. note::
+
+  There are 5 stages in total. The first and second stages will be automatically skipped
+  when choosing to download the codebook indexes prepared by `icefall`_.
+  Of course, you can also extract and compute the codebook indexes by yourself. This
+  requires you to download a HuBERT-XL model, and extracting the codebook indexes
+  can take a while.
+
+
+As usual, you can control the stages you want to run by specifying the following
+two options:
+
+  - ``--stage``
+  - ``--stop-stage``
+
+For example,
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0
+  $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 4
+
+Here are a few options in `./distillation_with_hubert.sh `_
+that you need to know before you proceed.
+
+- ``--full_libri`` If True, use the full 960h of training data. Otherwise, only ``train-clean-100`` will be used.
+- ``--use_extracted_codebook`` If True, the first two stages will be skipped and the codebook
+  indexes uploaded by us will be downloaded.
+
+Since we are using the pre-computed codebook indexes, we set
+``use_extracted_codebook=True``. If you want to do full `LibriSpeech`_
+experiments, please set ``full_libri=True``.
+
+The following command downloads the pre-computed codebook indexes
+and prepares MVQ-augmented training manifests.
+
+.. code-block:: bash
+
+  $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2
+
+Please see the
+following screenshot for the output of an example execution.
+
+.. figure:: ./images/distillation_codebook.png
+   :width: 800
+   :alt: Downloading codebook indexes and preparing training manifest.
+   :align: center
+
+   Downloading codebook indexes and preparing training manifest.
+
+.. hint::
+
+  The codebook indexes we prepared for you in this tutorial
+  are extracted from the 36-th layer of a fine-tuned HuBERT-XL model
+  with 8 codebooks. If you want to try other configurations, please
+  set ``use_extracted_codebook=False`` and set ``embedding_layer`` and
+  ``num_codebooks`` by yourself.
+
+Now, you should see the following files under the directory ``./data/vq_fbank_layer36_cb8``.
+
+.. figure:: ./images/distillation_directory.png
+   :width: 800
+   :alt: MVQ-augmented training manifests
+   :align: center
+
+   MVQ-augmented training manifests.
+
+Voila! You are ready to perform knowledge distillation training now!
+
+Training
+--------
+
+To perform training, please run stage 3 by executing the following command.
+
+.. code-block:: bash
+
+  $ ./distillation_with_hubert.sh --stage 3 --stop-stage 3 # run MVQ training
+
+Here is the code snippet for training:
+
+.. code-block:: bash
+
+  WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
+
+  ./pruned_transducer_stateless6/train.py \
+    --manifest-dir ./data/vq_fbank_layer36_cb8 \
+    --master-port 12359 \
+    --full-libri $full_libri \
+    --spec-aug-time-warp-factor -1 \
+    --max-duration 300 \
+    --world-size ${WORLD_SIZE} \
+    --num-epochs 30 \
+    --exp-dir $exp_dir \
+    --enable-distillation True \
+    --codebook-loss-scale 0.01
+
+There are a few arguments in the above training command that you should pay attention to. 
+
+  - ``--enable-distillation`` If True, knowledge distillation training is enabled.
+  - ``--codebook-loss-scale`` The scale of the knowledge distillation loss.
+  - ``--manifest-dir`` The path to the MVQ-augmented manifest.
+
+
+Decoding
+--------
+
+After training has finished, you can test the performance using
+the following command.
+
+.. code-block:: bash
+
+  export CUDA_VISIBLE_DEVICES=0
+  ./pruned_transducer_stateless6/decode.py \
+    --decoding-method "modified_beam_search" \
+    --epoch 30 \
+    --avg 10 \
+    --max-duration 200 \
+    --exp-dir $exp_dir \
+    --enable-distillation True
+
+You should get similar results as `here `_.
+
+That's all! Feel free to experiment with your own setups and report your results.
+If you encounter any problems during training, please open an issue `here `_.
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png
new file mode 100644
index 000000000..1a40d6c6e
Binary files /dev/null and b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png differ
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png
new file mode 100644
index 000000000..30763046f
Binary files /dev/null and b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png differ
diff --git a/docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
rename to docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg
new file mode 100644
index 000000000..800835749
Binary files /dev/null and b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg differ
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst
new file mode 100644
index 000000000..bf439861a
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst
@@ -0,0 +1,12 @@
+LibriSpeech
+===========
+
+.. toctree::
+   :maxdepth: 1
+
+   tdnn_lstm_ctc
+   conformer_ctc
+   pruned_transducer_stateless
+   zipformer_mmi
+   zipformer_ctc_blankskip
+   distillation
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst
new file mode 100644
index 000000000..42fd3df77
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst
@@ -0,0 +1,548 @@
+.. _non_streaming_librispeech_pruned_transducer_stateless:
+
+Pruned transducer statelessX
+============================
+
+This tutorial shows you how to run a conformer transducer model
+with the `LibriSpeech `_ dataset.
+
+.. 
Note:: + + The tutorial is suitable for `pruned_transducer_stateless `_, + `pruned_transducer_stateless2 `_, + `pruned_transducer_stateless4 `_, + `pruned_transducer_stateless5 `_, + We will take pruned_transducer_stateless4 as an example in this tutorial. + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use a Conformer model (the reworked version by Daniel Povey) + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + + +Data preparation +---------------- + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Training`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Training +-------- + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --help + + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--exp-dir`` + + The directory to save checkpoints, training logs and tensorboard. + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. 
+ + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless4/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless4/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless4/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless4/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless4/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless4/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless4/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless4/train.py`` directly. + + +.. NOTE:: + + The options for `pruned_transducer_stateless5 `_ are a little different from + other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless4/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... 
+ + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless4/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/QOGSPBgsR8KzcRMmie9JGw/ + + [2022-11-20T15:50:50] Started scanning logdir. + Uploading 4468 scalars... + [2022-11-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects + Listening for new data in logdir... + + Note there is a URL in the above output. Click it and you will see + the following screenshot: + + .. figure:: images/librispeech-pruned-transducer-tensorboard-log.jpg + :width: 600 + :alt: TensorBoard screenshot + :align: center + :target: https://tensorboard.dev/experiment/QOGSPBgsR8KzcRMmie9JGw/ + + TensorBoard screenshot. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 6 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5" + ./pruned_transducer_stateless4/train.py \ + --world-size 6 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless4/exp \ + --full-libri 1 \ + --max-duration 300 + + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless4/decode.py`` to use them. 
+ + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/decode.py --help + +shows the options for decoding. + +The following shows two examples (for two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 25 20; do + for avg in 7 5 3 1; do + ./pruned_transducer_stateless4/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless4/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. Note:: + + The supported decoding methods are as follows: + + - ``greedy_search`` : It takes the symbol with the largest posterior probability + at each frame as the decoding result. + + - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and + `espnet/nets/beam_search_transducer.py `_ + is used as a reference. Basically, it keeps the top-k states for each frame and expands the kept states with their own contexts to + the next frame. + + - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it + runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded. + + - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and + given ``FSAs``. It is hard to describe the details in a few lines of text; you can read + our paper at https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently. + + - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, except that ``fast_beam_search`` uses + a trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph + (with an N-gram LM). + + - ``fast_beam_search_nbest`` : It produces the decoding results as follows: + + - (1) Use ``fast_beam_search`` to get a lattice + - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()`` + - (3) Remove duplicates from the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + - ``fast_beam_search_nbest_LG`` : It implements the same logic as ``fast_beam_search_nbest``; the + only difference is that it uses ``fast_beam_search_LG`` to generate the lattice. + + +Export Model +------------ + +`pruned_transducer_stateless4/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless4/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless4/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +..
code-block:: bash + + # Assume that --epoch 25 --avg 3 produces the smallest WER + # (You can get such information after running ./pruned_transducer_stateless4/decode.py) + + epoch=25 + avg=3 + + ./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch $epoch \ + --avg $avg + +It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless4/decode.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp + ln -s pretrained.pt epoch-999.pt + + And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless4/decode.py``. + +To use the exported model with ``./pruned_transducer_stateless4/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless4/pretrained.py \ + --checkpoint ./pruned_transducer_stateless4/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 25 \ + --avg 3 \ + --jit 1 + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +.. NOTE:: + + You will need this ``cpu_jit.pt`` when deploying with Sherpa framework. + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless `_ + + - `pruned_transducer_stateless2 `_ + + - `pruned_transducer_stateless4 `_ + + - `pruned_transducer_stateless5 `_ + + See ``_ + for the details of the above pretrained models + + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst similarity index 100% rename from docs/source/recipes/librispeech/tdnn_lstm_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst new file mode 100644 index 000000000..aa73bfe33 --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -0,0 +1,454 @@ +Zipformer CTC Blank Skip +======================== + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +This tutorial shows you how to train a Zipformer model based on the guidance from +a co-trained CTC model using `blank skip method `_ +with the `LibriSpeech `_ dataset. + +.. note:: + + We use both CTC and RNN-T loss to train. 
During the forward pass, the encoder output + is first used to calculate the CTC posterior probability; then for each output frame, + if its blank posterior is bigger than some threshold, it will be simply discarded + from the encoder output. To prevent information loss, we also put a convolution module + similar to the one used in conformer (referred to as “LConv”) before the frame reduction. + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +.. note:: + + We encourage you to read ``./prepare.sh``. + +The data preparation contains several stages. You can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. hint:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. note:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + +Training +-------- + +For stability, it doesn`t use blank skip method until model warm-up. + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/train.py --help + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless7_ctc_bs/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless7_ctc_bs/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless7_ctc_bs/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless7_ctc_bs/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. 
code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., weight decay, +number of warmup steps, results dir, etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless7_ctc_bs/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless7_ctc_bs/train.py`` directly. + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``pruned_transducer_stateless7_ctc_bs/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_ctc_bs/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_ctc_bs/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard + $ tensorboard dev upload --logdir . --description "Zipformer-CTC co-training using blank skip for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/ + + Note there is a URL in the above output. Click it and you will see + tensorboard. + + .. 
hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --full-libri 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --use-fp16 1 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py --help + +shows the options for decoding. + +The following shows the example using ``epoch-*.pt``: + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch 30 \ + --avg 13 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method $m + done + +To test CTC branch, you can use the following command: + +.. code-block:: bash + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch 30 \ + --avg 13 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method $m + done + +Export models +------------- + +`pruned_transducer_stateless7_ctc_bs/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless7_ctc_bs/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless7_ctc_bs/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 0 + +It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt``. + +.. 
hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless7_ctc_bs/exp + ln -s pretrained epoch-9999.pt + + And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``. + +To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +To test CTC branch using the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +To use the generated files with ``./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav + +To test CTC branch using the generated files with ``./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - trained on LibriSpeech 100h: ``_ + - trained on LibriSpeech 960h: ``_ + + See ``_ + for the details of the above pretrained models diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst new file mode 100644 index 000000000..a7b59a992 --- /dev/null +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst @@ -0,0 +1,422 @@ +Zipformer MMI +=============== + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +This tutorial shows you how to train an Zipformer MMI model +with the `LibriSpeech `_ dataset. + +We use LF-MMI to compute the loss. + +.. note:: + + You can find the document about LF-MMI training at the following address: + + ``_ + + +Data preparation +---------------- + +.. 
code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +.. note:: + + We encourage you to read ``./prepare.sh``. + +The data preparation contains several stages. You can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. hint:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. note:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + +Training +-------- + +For stability, it uses CTC loss for model warm-up and then switches to MMI loss. + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./zipformer_mmi/train.py --help + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./zipformer_mmi/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./zipformer_mmi/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./zipformer_mmi/train.py --start-epoch 10`` loads the + checkpoint ``./zipformer_mmi/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./zipformer_mmi/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./zipformer_mmi/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. 
code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./zipformer_mmi/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., weight decay, +number of warmup steps, results dir, etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`zipformer_mmi/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./zipformer_mmi/train.py`` directly. + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``zipformer_mmi/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./zipformer_mmi/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./zipformer_mmi/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd zipformer_mmi/exp/tensorboard + $ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/ + + Note there is a URL in the above output. Click it and you will see + tensorboard. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd zipformer_mmi/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. 
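The checkpoint files mentioned above are ordinary PyTorch checkpoints. If you want to peek inside one, here is a minimal sketch; the checkpoint path is illustrative, and the exact set of keys depends on the training script, so check the printed keys first:

.. code-block:: python

    import torch

    # Load a training checkpoint on CPU; adjust the path to an existing file.
    ckpt = torch.load("zipformer_mmi/exp/epoch-10.pt", map_location="cpu")

    # The checkpoint is a plain dict containing (at least) the model
    # ``state_dict`` and the optimizer ``state_dict``.
    print(sorted(ckpt.keys()))

    # If the model weights are stored under the key "model" (an assumption;
    # verify against the printed keys), show a few parameter names and shapes.
    if "model" in ckpt:
        for name, tensor in list(ckpt["model"].items())[:5]:
            print(name, tuple(tensor.shape))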
+ +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./zipformer_mmi/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --full-libri 1 \ + --exp-dir zipformer_mmi/exp \ + --max-duration 500 \ + --use-fp16 1 \ + --num-workers 2 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``zipformer_mmi/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``zipformer_mmi/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./zipformer_mmi/decode.py --help + +shows the options for decoding. + +The following shows the example using ``epoch-*.pt``: + +.. code-block:: bash + + for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do + ./zipformer_mmi/decode.py \ + --epoch 30 \ + --avg 10 \ + --exp-dir ./zipformer_mmi/exp/ \ + --max-duration 100 \ + --lang-dir data/lang_bpe_500 \ + --nbest-scale 1.2 \ + --hp-scale 1.0 \ + --decoding-method $m + done + + +Export models +------------- + +`zipformer_mmi/export.py `_ supports exporting checkpoints from ``zipformer_mmi/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``zipformer_mmi/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + ./zipformer_mmi/export.py \ + --exp-dir ./zipformer_mmi/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 0 + +It will generate a file ``./zipformer_mmi/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``zipformer_mmi/decode.py``, + you can run: + + .. code-block:: bash + + cd zipformer_mmi/exp + ln -s pretrained epoch-9999.pt + + And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to + ``./zipformer_mmi/decode.py``. + +To use the exported model with ``./zipformer_mmi/pretrained.py``, you +can run: + +.. code-block:: bash + + ./zipformer_mmi/pretrained.py \ + --checkpoint ./zipformer_mmi/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method 1best \ + /path/to/foo.wav \ + /path/to/bar.wav + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./zipformer_mmi/export.py \ + --exp-dir ./zipformer_mmi/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +To use the generated files with ``./zipformer_mmi/jit_pretrained.py``: + +.. 
code-block:: bash + + ./zipformer_mmi/jit_pretrained.py \ + --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method 1best \ + /path/to/foo.wav \ + /path/to/bar.wav + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - ``_ + + See ``_ + for the details of the above pretrained models diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/Non-streaming-ASR/timit/index.rst similarity index 98% rename from docs/source/recipes/timit/index.rst rename to docs/source/recipes/Non-streaming-ASR/timit/index.rst index 17f40cdb7..5ee147be7 100644 --- a/docs/source/recipes/timit/index.rst +++ b/docs/source/recipes/Non-streaming-ASR/timit/index.rst @@ -6,4 +6,3 @@ TIMIT tdnn_ligru_ctc tdnn_lstm_ctc - diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst similarity index 97% rename from docs/source/recipes/timit/tdnn_ligru_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst index 186420ee7..3d7aefe02 100644 --- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst @@ -148,10 +148,10 @@ Some commonly used options are: $ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17 - uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, - ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, - ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, - ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, + uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``, + ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``, + ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``, + ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use: .. 
code-block:: bash - ./tdnn_ligru_ctc/pretrained.py + ./tdnn_ligru_ctc/pretrained.py --method 1best - --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt - --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt - --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt + --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt + --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -337,7 +337,7 @@ The output is: 2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started 2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 20:41:39,829 INFO [pretrained.py:267] + 2021-11-08 20:41:39,829 INFO [pretrained.py:267] ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh @@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \ --G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.1 \ - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -378,7 +378,7 @@ The decoding output is: 2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started 2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:37:56,348 INFO [pretrained.py:267] + 2021-11-08 20:37:56,348 INFO [pretrained.py:267] ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst similarity index 97% rename from docs/source/recipes/timit/tdnn_lstm_ctc.rst rename to docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst index 6f760a9ce..ee67a6edc 100644 --- 
a/docs/source/recipes/timit/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst @@ -148,8 +148,8 @@ Some commonly used options are: $ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10 - uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, - ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, + uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, + ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt`` for decoding. @@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use: .. code-block:: bash - ./tdnn_lstm_ctc/pretrained.py + ./tdnn_lstm_ctc/pretrained.py --method 1best - --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt - --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt - --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt + --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt + --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The output is: @@ -335,7 +335,7 @@ The output is: 2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started 2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding - 2021-11-08 21:02:54,387 INFO [pretrained.py:267] + 2021-11-08 21:02:54,387 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh @@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \ --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.08 \ - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV - ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV + ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV The decoding output is: @@ -376,7 +376,7 @@ The decoding output is: 2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV'] 2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started 2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring - 2021-11-08 20:05:27,878 INFO [pretrained.py:267] + 
2021-11-08 20:05:27,878 INFO [pretrained.py:267] ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV: sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
diff --git a/docs/source/recipes/yesno/images/tdnn-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png similarity index 100% rename from docs/source/recipes/yesno/images/tdnn-tensorboard-log.png rename to docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png
diff --git a/docs/source/recipes/yesno/index.rst b/docs/source/recipes/Non-streaming-ASR/yesno/index.rst similarity index 100% rename from docs/source/recipes/yesno/index.rst rename to docs/source/recipes/Non-streaming-ASR/yesno/index.rst
diff --git a/docs/source/recipes/yesno/tdnn.rst b/docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst similarity index 100% rename from docs/source/recipes/yesno/tdnn.rst rename to docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst
diff --git a/docs/source/recipes/Streaming-ASR/index.rst b/docs/source/recipes/Streaming-ASR/index.rst new file mode 100644 index 000000000..8c0ffe447 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/index.rst @@ -0,0 +1,12 @@ +Streaming ASR +============= + +.. toctree:: + :maxdepth: 1 + + introduction + +.. toctree:: + :maxdepth: 2 + + librispeech/index
diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst new file mode 100644 index 000000000..e1382e77d --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/introduction.rst @@ -0,0 +1,53 @@ +Introduction +============ + +This page shows you how we implement streaming **X-former transducer** models for ASR. + +.. HINT:: + X-former transducer here means that the encoder of the transducer model uses Multi-Head Attention, + like `Conformer `_, `EmFormer `_, etc. + +Currently we have implemented two types of streaming models: one uses Conformer as the encoder, the other uses Emformer as the encoder. + +Streaming Conformer +------------------- + +The main idea of training a streaming model is to make the model see only limited context +at training time; we can achieve this by applying a mask to the output of self-attention. +In icefall, we implement the streaming conformer in the same way as `WeNet `_ does. + +.. NOTE:: + The conformer-transducer recipes for the LibriSpeech dataset, such as `pruned_transducer_stateless `_, + `pruned_transducer_stateless2 `_, + `pruned_transducer_stateless3 `_, + `pruned_transducer_stateless4 `_, + `pruned_transducer_stateless5 `_, + all support streaming. + +.. NOTE:: + Training a streaming conformer model in ``icefall`` is almost the same as training a + non-streaming model; all you need to do is pass several extra arguments. + See :doc:`Pruned transducer statelessX ` for more details. + +.. HINT:: + If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer + to `this pull request `_. After adding the code needed by streaming training, + you have to re-train it with the extra arguments mentioned in the docs above to get a streaming model. + + +Streaming Emformer +------------------ + +The Emformer model proposed `here `_ uses more +complicated techniques. It has a memory bank component to memorize history information; +what's more, it also introduces right context at training time by hard-copying part of
the input features. The toy sketch below illustrates this right-context copying.
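The following is a toy illustration of that idea, not icefall's actual Emformer code or API: it splits a feature sequence into fixed-size chunks and hard-copies a few frames to the right of each chunk as its right context. The chunk size, right-context length, and feature shape are made-up values for the demo.

.. code-block:: python

    import torch

    def chunks_with_right_context(feats, chunk_size=8, right_context=2):
        """Split (num_frames, feat_dim) features into chunks, appending a hard
        copy of the next ``right_context`` frames to the end of each chunk."""
        out = []
        num_frames = feats.size(0)
        for start in range(0, num_frames, chunk_size):
            end = min(start + chunk_size, num_frames)
            rc_end = min(end + right_context, num_frames)
            # the chunk itself plus a copy of the frames just to its right
            out.append(torch.cat([feats[start:end], feats[end:rc_end]], dim=0))
        return out

    feats = torch.randn(20, 80)  # 20 frames of 80-dim features, random for the demo
    for i, chunk in enumerate(chunks_with_right_context(feats)):
        print(i, tuple(chunk.shape))

In the real models, the memory bank additionally carries a summary of the chunks that have already been processed; see the recipes listed below for the concrete implementations.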
+ +We have three variants of Emformer models in ``icefall``. + + - ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe `_. + - ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourself. Different from the Emformer in torchaudio, + ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model. + See `LibriSpeech recipe `_. + - ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourself. The only difference from the above one is that + it uses a simplified memory bank. See `LibriSpeech recipe `_. diff --git a/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png b/docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png similarity index 100% rename from docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png rename to docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png diff --git a/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg new file mode 100644 index 000000000..9c77b8bae Binary files /dev/null and b/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg differ diff --git a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/Streaming-ASR/librispeech/index.rst similarity index 61% rename from docs/source/recipes/librispeech/index.rst rename to docs/source/recipes/Streaming-ASR/librispeech/index.rst index 6c91b6750..d52e08058 100644 --- a/docs/source/recipes/librispeech/index.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/index.rst @@ -4,6 +4,8 @@ LibriSpeech .. toctree:: :maxdepth: 1 - tdnn_lstm_ctc - conformer_ctc + pruned_transducer_stateless + lstm_pruned_stateless_transducer + + zipformer_transducer diff --git a/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst similarity index 81% rename from docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst rename to docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst index 643855cc2..911e84656 100644 --- a/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst +++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst @@ -515,110 +515,6 @@ To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained``: Please see ``_ for how to use the exported models in ``sherpa``. -.. _export-model-for-ncnn: - -Export model for ncnn -~~~~~~~~~~~~~~~~~~~~~ - -We support exporting pretrained LSTM transducer models to -`ncnn `_ using -`pnnx `_. - -First, let us install a modified version of ``ncnn``: - -.. code-block:: bash - - git clone https://github.com/csukuangfj/ncnn - cd ncnn - git submodule update --recursive --init - python3 setup.py bdist_wheel - ls -lh dist/ - pip install ./dist/*.whl - - # now build pnnx - cd tools/pnnx - mkdir build - cd build - make -j4 - export PATH=$PWD/src:$PATH - - ./src/pnnx - -.. note:: - - We assume that you have added the path to the binary ``pnnx`` to the - environment variable ``PATH``. - -Second, let us export the model using ``torch.jit.trace()`` that is suitable -for ``pnnx``: - -.. 
code-block:: bash - - iter=468000 - avg=16 - - ./lstm_transducer_stateless2/export.py \ - --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --iter $iter \ - --avg $avg \ - --pnnx 1 - -It will generate 3 files: - - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt`` - -Third, convert torchscript model to ``ncnn`` format: - -.. code-block:: - - pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt - pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt - pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt - -It will generate the following files: - - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin`` - -To use the above generated files, run: - -.. code-block:: bash - - ./lstm_transducer_stateless2/ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ - --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ - /path/to/foo.wav - -.. code-block:: bash - - ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ - --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ - /path/to/foo.wav - -To use the above generated files in C++, please see -``_ - -It is able to generate a static linked executable that can be run on Linux, Windows, -macOS, Raspberry Pi, etc, without external dependencies. - Download pretrained models -------------------------- diff --git a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst new file mode 100644 index 000000000..de7102ba8 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst @@ -0,0 +1,735 @@ +Pruned transducer statelessX +============================ + +This tutorial shows you how to run a **streaming** conformer transducer model +with the `LibriSpeech `_ dataset. + +.. 
Note:: + + The tutorial is suitable for `pruned_transducer_stateless `_, + `pruned_transducer_stateless2 `_, + `pruned_transducer_stateless4 `_, + `pruned_transducer_stateless5 `_, + We will take pruned_transducer_stateless4 as an example in this tutorial. + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use a Conformer model (the reworked version by Daniel Povey) + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + + +Data preparation +---------------- + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Training`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Training +-------- + +.. NOTE:: + + We put the streaming and non-streaming model in one recipe, to train a streaming model you only + need to add **4** extra options comparing with training a non-streaming model. These options are + ``--dynamic-chunk-training``, ``--num-left-chunks``, ``--causal-convolution``, ``--short-chunk-size``. + You can see the configurable options below for their meanings or read https://arxiv.org/pdf/2012.05481.pdf for more details. + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --help + + +shows you the training options that can be passed from the commandline. 
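+
+.. hint::
+
+   The streaming-related options mentioned in the note above (``--dynamic-chunk-training``,
+   ``--short-chunk-size``, ``--num-left-chunks``, ``--causal-convolution``) are explained in
+   the list below. To give an intuition for the chunk-wise attention masking they control,
+   here is a minimal, self-contained sketch of the idea only; it is **not** the code used by
+   the recipe, and the helper name ``make_chunk_mask`` is made up for illustration:
+
+   .. code-block:: python
+
+      import torch
+
+      def make_chunk_mask(seq_len: int, chunk_size: int, num_left_chunks: int) -> torch.Tensor:
+          # Entry (i, j) is True if query frame i may attend to key frame j:
+          # a frame sees its own chunk and at most `num_left_chunks` chunks to its left.
+          chunk_idx = torch.arange(seq_len) // chunk_size
+          can_see_left = chunk_idx.unsqueeze(0) >= chunk_idx.unsqueeze(1) - num_left_chunks
+          not_future = chunk_idx.unsqueeze(0) <= chunk_idx.unsqueeze(1)
+          return can_see_left & not_future
+
+      # 8 frames, chunks of 2 frames, 1 left chunk visible for each chunk.
+      print(make_chunk_mask(seq_len=8, chunk_size=2, num_left_chunks=1).int())
+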
+The following options are used quite often: + + - ``--exp-dir`` + + The directory to save checkpoints, training logs and tensorboard. + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless4/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless4/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless4/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless4/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless4/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless4/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + - ``--dynamic-chunk-training`` + + The flag that indicates whether to train a streaming model or not, it + **MUST** be True if you want to train a streaming model. + + - ``--short-chunk-size`` + + When training a streaming attention model with chunk masking, the chunk size + would be either max sequence length of current batch or uniformly sampled from + (1, short_chunk_size). The default value is 25, you don't have to change it most of the time. + + - ``--num-left-chunks`` + + It indicates how many left context (in chunks) that can be seen when calculating attention. 
+ The default value is 4, you don't have to change it most of the time. + + + - ``--causal-convolution`` + + Whether to use causal convolution in conformer encoder layer, this requires + to be True when training a streaming model. + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless4/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless4/train.py`` directly. + + +.. NOTE:: + + The options for `pruned_transducer_stateless5 `_ are a little different from + other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless4/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless4/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless4/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/ + + [2022-11-20T15:50:50] Started scanning logdir. + Uploading 4468 scalars... + [2022-11-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects + Listening for new data in logdir... + + Note there is a URL in the above output. Click it and you will see + the following screenshot: + + .. figure:: images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg + :width: 600 + :alt: TensorBoard screenshot + :align: center + :target: https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/ + + TensorBoard screenshot. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. 
code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless4/train.py \ + --world-size 4 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless4/exp \ + --full-libri 1 \ + --max-duration 300 + +.. NOTE:: + + Comparing with training a non-streaming model, you only need to add two extra options, + ``--dynamic-chunk-training 1`` and ``--causal-convolution 1`` . + + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless4/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. tip:: + + To decode a streaming model, you can use either ``simulate streaming decoding`` in ``decode.py`` or + ``real streaming decoding`` in ``streaming_decode.py``, the difference between ``decode.py`` and + ``streaming_decode.py`` is that, ``decode.py`` processes the whole acoustic frames at one time with masking (i.e. same as training), + but ``streaming_decode.py`` processes the acoustic frames chunk by chunk (so it can only see limited context). + +.. NOTE:: + + ``simulate streaming decoding`` in ``decode.py`` and ``real streaming decoding`` in ``streaming_decode.py`` should + produce almost the same results given the same ``--decode-chunk-size`` and ``--left-context``. + + +Simulate streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--simulate-streaming`` + + If you want to decode a streaming model with ``decode.py``, you **MUST** set + ``--simulate-streaming`` to ``True``. ``simulate`` here means the acoustic frames + are not processed frame by frame (or chunk by chunk), instead, the whole sequence + is processed at one time with masking (the same as training). + + ``--causal-convolution`` + + If True, the convolution module in encoder layers will be causal convolution. + This is **MUST** be True when decoding with a streaming model. + + ``--decode-chunk-size`` + + For streaming models, we will calculate the chunk-wise attention, ``--decode-chunk-size`` + indicates the chunk length (in frames after subsampling) for chunk-wise attention. + For ``simulate streaming decoding`` the ``decode-chunk-size`` is used to generate + the attention mask. 
+ + ``--left-context`` + + ``--left-context`` indicates how many left context frames (after subsampling) can be seen + for current chunk when calculating chunk-wise attention. Normally, ``left-context`` should equal + to ``decode-chunk-size * num-left-chunks``, where ``num-left-chunks`` is the option used + to train this model. For ``simulate streaming decoding`` the ``left-context`` is used to generate + the attention mask. + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 25 20; do + for avg in 7 5 3 1; do + ./pruned_transducer_stateless4/decode.py \ + --epoch $epoch \ + --avg $avg \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless4/decode.py \ + --iter $iter \ + --avg $avg \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +Real streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless4/streaming_decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--decode-chunk-size`` + + For streaming models, we will calculate the chunk-wise attention, ``--decode-chunk-size`` + indicates the chunk length (in frames after subsampling) for chunk-wise attention. + For ``real streaming decoding``, we will process ``decode-chunk-size`` acoustic frames at each time. + + ``--left-context`` + + ``--left-context`` indicates how many left context frames (after subsampling) can be seen + for current chunk when calculating chunk-wise attention. Normally, ``left-context`` should equal + to ``decode-chunk-size * num-left-chunks``, where ``num-left-chunks`` is the option used + to train this model. + + ``--num-decode-streams`` + + The number of decoding streams that can be run in parallel (very similar to the ``bath size``). + For ``real streaming decoding``, the batches will be packed dynamically, for example, if the + ``num-decode-streams`` equals to 10, then, sequence 1 to 10 will be decoded at first, after a while, + suppose sequence 1 and 2 are done, so, sequence 3 to 12 will be processed parallelly in a batch. + + +.. NOTE:: + + We also try adding ``--right-context`` in the real streaming decoding, but it seems not to benefit + the performance for all the models, the reasons might be the training and decoding mismatch. You + can try decoding with ``--right-context`` to see if it helps. The default value is 0. + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 25 20; do + for avg in 7 5 3 1; do + ./pruned_transducer_stateless4/decode.py \ + --epoch $epoch \ + --avg $avg \ + --decode-chunk-size 16 \ + --left-context 64 \ + --num-decode-streams 100 \ + --exp-dir pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. 
code-block:: bash
+
+  for m in greedy_search fast_beam_search modified_beam_search; do
+    for iter in 474000; do
+      for avg in 8 10 12 14 16 18; do
+        ./pruned_transducer_stateless4/decode.py \
+          --iter $iter \
+          --avg $avg \
+          --decode-chunk-size 16 \
+          --left-context 64 \
+          --num-decode-streams 100 \
+          --exp-dir pruned_transducer_stateless4/exp \
+          --max-duration 600 \
+          --decoding-method $m
+      done
+    done
+  done
+
+
+.. tip::
+
+   The supported decoding methods are as follows:
+
+   - ``greedy_search`` : It takes the symbol with the largest posterior probability
+     at each frame as the decoding result.
+
+   - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf;
+     `espnet/nets/beam_search_transducer.py `_
+     is used as a reference. Basically, it keeps the top-k states for each frame and expands
+     the kept states with their own contexts to the next frame.
+
+   - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
+     runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded.
+
+   - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and
+     the given ``FSAs``. It is hard to describe the details in a few lines of text; you can read
+     our paper at https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_.
+     ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently.
+
+   - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, except that ``fast_beam_search`` uses
+     a trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph
+     (with an N-gram LM).
+
+   - ``fast_beam_search_nbest`` : It produces the decoding results as follows:
+
+     - (1) Use ``fast_beam_search`` to get a lattice
+     - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()``
+     - (3) Deduplicate the selected paths
+     - (4) Intersect the selected paths with the lattice and compute the
+       shortest path from the intersection result
+     - (5) The path with the largest score is used as the decoding output.
+
+   - ``fast_beam_search_nbest_LG`` : It implements the same logic as ``fast_beam_search_nbest``; the
+     only difference is that it uses ``fast_beam_search_LG`` to generate the lattice.
+
+.. NOTE::
+
+   ``streaming_decode.py`` may support fewer decoding methods than ``decode.py``; if needed,
+   you can implement them yourself or file an issue in `icefall `_ .
+
+
+Export Model
+------------
+
+`pruned_transducer_stateless4/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless4/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``pruned_transducer_stateless4/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+  # Assume that --epoch 25 --avg 3 produces the smallest WER
+  # (You can get such information after running ./pruned_transducer_stateless4/decode.py)
+
+  epoch=25
+  avg=3
+
+  ./pruned_transducer_stateless4/export.py \
+    --exp-dir ./pruned_transducer_stateless4/exp \
+    --streaming-model 1 \
+    --causal-convolution 1 \
+    --bpe-model data/lang_bpe_500/bpe.model \
+    --epoch $epoch \
+    --avg $avg
+
+.. caution::
+
+   ``--streaming-model`` and ``--causal-convolution`` must be True to export
+   a streaming model.
+ +It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless4/decode.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless4/exp + ln -s pretrained.pt epoch-999.pt + + And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless4/decode.py``. + +To use the exported model with ``./pruned_transducer_stateless4/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless4/pretrained.py \ + --checkpoint ./pruned_transducer_stateless4/exp/pretrained.pt \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless4/export.py \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --streaming-model 1 \ + --causal-convolution 1 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 25 \ + --avg 3 \ + --jit 1 + +.. caution:: + + ``--streaming-model`` and ``--causal-convolution`` require to be True to export + a streaming mdoel. + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. + +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +.. NOTE:: + + You will need this ``cpu_jit.pt`` when deploying with Sherpa framework. + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless `_ + + - `pruned_transducer_stateless2 `_ + + - `pruned_transducer_stateless4 `_ + + - `pruned_transducer_stateless5 `_ + + See ``_ + for the details of the above pretrained models + + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst new file mode 100644 index 000000000..f0e8961d7 --- /dev/null +++ b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst @@ -0,0 +1,654 @@ +Zipformer Transducer +==================== + +This tutorial shows you how to run a **streaming** zipformer transducer model +with the `LibriSpeech `_ dataset. + +.. Note:: + + The tutorial is suitable for `pruned_transducer_stateless7_streaming `_, + +.. HINT:: + + We assume you have read the page :ref:`install icefall` and have setup + the environment for ``icefall``. + +.. HINT:: + + We recommend you to use a GPU or several GPUs to run this recipe. + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use a Zipformer model (proposed by Daniel Povey) + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. 
caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + + +Data preparation +---------------- + +.. hint:: + + The data preparation is the same as other recipes on LibriSpeech dataset, + if you have finished this step, you can skip to ``Training`` directly. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +The data preparation contains several stages, you can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. HINT:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. NOTE:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + + +Training +-------- + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/train.py --help + + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--exp-dir`` + + The directory to save checkpoints, training logs and tensorboard. + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./pruned_transducer_stateless7_streaming/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./pruned_transducer_stateless7_streaming/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./pruned_transducer_stateless7_streaming/train.py --start-epoch 10`` loads the + checkpoint ``./pruned_transducer_stateless7_streaming/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. 
code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./pruned_transducer_stateless7_streaming/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--use-fp16`` + + If it is True, the model will train with half precision, from our experiment + results, by using half precision you can train with two times larger ``--max-duration`` + so as to get almost 2X speed up. + + We recommend using ``--use-fp16 True``. + + - ``--short-chunk-size`` + + When training a streaming attention model with chunk masking, the chunk size + would be either max sequence length of current batch or uniformly sampled from + (1, short_chunk_size). The default value is 50, you don't have to change it most of the time. + + - ``--num-left-chunks`` + + It indicates how many left context (in chunks) that can be seen when calculating attention. + The default value is 4, you don't have to change it most of the time. + + + - ``--decode-chunk-len`` + + The chunk size for decoding (in frames before subsampling). It is used for validation. + The default value is 32 (i.e., 320ms). + + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., number of encoder layers, +encoder dimension, decoder dimension, number of warmup steps etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`pruned_transducer_stateless7_streaming/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./pruned_transducer_stateless7_streaming/train.py`` directly. + + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``--exp-dir`` (e.g. ``pruned_transducer_stateless7_streaming/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./pruned_transducer_stateless7_streaming/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. 
code-block:: bash + + $ ./pruned_transducer_stateless7_streaming/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd pruned_transducer_stateless7_streaming/exp/tensorboard + $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall" + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd pruned_transducer_stateless7_streaming/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 4 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``pruned_transducer_stateless7_streaming/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``pruned_transducer_stateless7_streaming/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. tip:: + + To decode a streaming model, you can use either ``simulate streaming decoding`` in ``decode.py`` or + ``real chunk-wise streaming decoding`` in ``streaming_decode.py``. The difference between ``decode.py`` and + ``streaming_decode.py`` is that, ``decode.py`` processes the whole acoustic frames at one time with masking (i.e. same as training), + but ``streaming_decode.py`` processes the acoustic frames chunk by chunk. + +.. NOTE:: + + ``simulate streaming decoding`` in ``decode.py`` and ``real chunk-size streaming decoding`` in ``streaming_decode.py`` should + produce almost the same results given the same ``--decode-chunk-len``. + + +Simulate streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--decode-chunk-len`` + + It is same as in ``train.py``, which specifies the chunk size for decoding (in frames before subsampling). + The default value is 32 (i.e., 320ms). + + +The following shows two examples (for the two types of checkpoints): + +.. 
code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 30; do + for avg in 12 11 10 9 8; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch $epoch \ + --avg $avg \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --iter $iter \ + --avg $avg \ + --decode-chunk-len 32 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decoding-method $m + done + done + done + + +Real streaming decoding +~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./pruned_transducer_stateless7_streaming/streaming_decode.py --help + +shows the options for decoding. +The following options are important for streaming models: + + ``--decode-chunk-len`` + + It is same as in ``train.py``, which specifies the chunk size for decoding (in frames before subsampling). + The default value is 32 (i.e., 320ms). + For ``real streaming decoding``, we will process ``decode-chunk-len`` acoustic frames at each time. + + ``--num-decode-streams`` + + The number of decoding streams that can be run in parallel (very similar to the ``bath size``). + For ``real streaming decoding``, the batches will be packed dynamically, for example, if the + ``num-decode-streams`` equals to 10, then, sequence 1 to 10 will be decoded at first, after a while, + suppose sequence 1 and 2 are done, so, sequence 3 to 12 will be processed parallelly in a batch. + + +The following shows two examples (for the two types of checkpoints): + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 30; do + for avg in 12 11 10 9 8; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch $epoch \ + --avg $avg \ + --decode-chunk-len 32 \ + --num-decode-streams 100 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --decoding-method $m + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --iter $iter \ + --avg $avg \ + --decode-chunk-len 16 \ + --num-decode-streams 100 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --decoding-method $m + done + done + done + + +.. tip:: + + Supporting decoding methods are as follows: + + - ``greedy_search`` : It takes the symbol with largest posterior probability + of each frame as the decoding result. + + - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and + `espnet/nets/beam_search_transducer.py `_ + is used as a reference. Basicly, it keeps topk states for each frame, and expands the kept states with their own contexts to + next frame. + + - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it + runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded. + + - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and + given ``FSAs``. It is hard to describe the details in several lines of texts, you can read + our paper in https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. 
``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently. + + - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, ``fast_beam_search`` uses + an trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph + (with N-gram LM). + + - ``fast_beam_search_nbest`` : It produces the decoding results as follows: + + - (1) Use ``fast_beam_search`` to get a lattice + - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()`` + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + - ``fast_beam_search_nbest_LG`` : It implements same logic as ``fast_beam_search_nbest``, the + only difference is that it uses ``fast_beam_search_LG`` to generate the lattice. + +.. NOTE:: + + The supporting decoding methods in ``streaming_decode.py`` might be less than that in ``decode.py``, if needed, + you can implement them by yourself or file a issue in `icefall `_ . + + +Export Model +------------ + +Currently it supports exporting checkpoints from ``pruned_transducer_stateless7_streaming/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``pruned_transducer_stateless7_streaming/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + # Assume that --epoch 30 --avg 9 produces the smallest WER + # (You can get such information after running ./pruned_transducer_stateless7_streaming/decode.py) + + epoch=30 + avg=9 + + ./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +It will generate a file ``./pruned_transducer_stateless7_streaming/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_streaming/decode.py``, + you can run: + + .. code-block:: bash + + cd pruned_transducer_stateless7_streaming/exp + ln -s pretrained.pt epoch-999.pt + + And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to + ``./pruned_transducer_stateless7_streaming/decode.py``. + +To use the exported model with ``./pruned_transducer_stateless7_streaming/pretrained.py``, you +can run: + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +Export model using ``torch.jit.script()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --decode-chunk-len 32 \ + --jit 1 + +.. caution:: + + ``--decode-chunk-len`` is required to export a ScriptModule. + +It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later +load it by ``torch.jit.load("cpu_jit.pt")``. 
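+
+For example, a minimal sketch of loading the exported file from Python (the directory
+below is the ``--exp-dir`` used above; how the loaded module is then invoked depends on
+the recipe and is not shown here):
+
+.. code-block:: python
+
+   import torch
+
+   # Load the TorchScript model; after loading, its parameters live on CPU.
+   model = torch.jit.load("pruned_transducer_stateless7_streaming/exp/cpu_jit.pt")
+   model.eval()
+
+   # Move the parameters to a GPU if one is available.
+   if torch.cuda.is_available():
+       model = model.to("cuda")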
+ +Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python +are on CPU. You can use ``to("cuda")`` to move them to a CUDA device. + +Export model using ``torch.jit.trace()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + epoch=30 + avg=9 + + ./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model=True \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch $epoch \ + --avg $avg + +.. caution:: + + ``--decode-chunk-len`` is required to export a ScriptModule. + +It will generate 3 files: + + - ``./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt`` + - ``./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt`` + - ``./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt`` + +To use the generated files with ``./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py``: + +.. code-block:: bash + + ./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + /path/to/foo.wav + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - `pruned_transducer_stateless7_streaming `_ + + See ``_ + for the details of the above pretrained models + +Deploy with Sherpa +------------------ + +Please see ``_ +for how to deploy the models in ``sherpa``. diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 9d1d83d29..63793275c 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -13,7 +13,5 @@ We may add recipes for other tasks as well in the future. 
:maxdepth: 2 :caption: Table of Contents - aishell/index - librispeech/index - timit/index - yesno/index + Non-streaming-ASR/index + Streaming-ASR/index diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py index fb2751c0f..387c14acf 100755 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -116,9 +114,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_char.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py index 71be2a613..85047c367 100755 --- a/egs/aidatatang_200zh/ASR/local/text2token.py +++ b/egs/aidatatang_200zh/ASR/local/text2token.py @@ -56,9 +56,7 @@ def get_parser(): parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +64,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +104,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +130,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 039951354..46ecd5769 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail stage=-1 @@ -106,11 +109,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ ! 
-f $lang_char_dir/words.txt ]; then ./local/prepare_words.py \ --input-file $lang_char_dir/words_no_ids.txt \ - --output-file $lang_char_dir/words.txt + --output-file $lang_char_dir/words.txt fi if [ ! -f $lang_char_dir/L_disambig.pt ]; then ./local/prepare_char.py fi fi - diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 6a5b57e24..167d5e15e 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -205,17 +205,13 @@ class Aidatatang_200zhAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -237,9 +233,7 @@ class Aidatatang_200zhAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -282,9 +276,7 @@ class Aidatatang_200zhAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +332,7 @@ class Aidatatang_200zhAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index f0407f429..2512f233f 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -69,11 +69,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -192,8 +188,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +244,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -266,10 +259,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -390,9 +380,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -403,18 +391,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -424,10 +408,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py index 00b54c39f..e348f7b2b 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py @@ -103,8 +103,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) return parser @@ -173,9 +172,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py index eb5e6b0d4..75c316eaf 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py @@ -162,8 +162,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -193,10 +192,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,9 +255,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,10 +280,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -339,9 +332,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index d46838b68..c9d9c4aa8 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -81,9 +81,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -187,8 +185,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -211,8 +208,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -542,22 +538,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -711,9 +700,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -813,7 +800,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index 75fc6326e..f4a59e552 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -1,7 +1,7 @@ # Introduction -Please refer to +Please refer to for how to run models in this recipe. 
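Aside (not part of the patch above): the `compute_loss` hunk only reflows the warmup-gated combination of the simple and pruned transducer losses to 88 columns; the schedule itself is unchanged. A minimal standalone sketch of that schedule, using the thresholds from the hunk and a placeholder value for `params.simple_loss_scale`, could look like this:

```python
def combine_transducer_losses(
    simple_loss: float,
    pruned_loss: float,
    warmup: float,
    simple_loss_scale: float = 0.5,  # placeholder; the recipe reads params.simple_loss_scale
) -> float:
    """Mirror of the pruned_loss_scale logic in the hunk above.

    The pruned loss is switched off while warmup < 1.0, down-weighted to 0.1
    between 1.0 and 2.0, and used at full weight afterwards, so it cannot
    overwhelm the simple loss before the alignment has been learned.
    """
    if warmup < 1.0:
        pruned_loss_scale = 0.0
    elif 1.0 < warmup < 2.0:
        pruned_loss_scale = 0.1
    else:
        pruned_loss_scale = 1.0
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss


# Early in training only the simple loss contributes; later both do.
print(combine_transducer_losses(12.3, 45.6, warmup=0.5))  # 0.5 * 12.3 = 6.15
print(combine_transducer_losses(12.3, 45.6, warmup=1.5))  # 6.15 + 0.1 * 45.6 = 10.71
```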
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 9a06fbe9f..aa18502c2 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -2,6 +2,57 @@ ### Aishell training result(Stateless Transducer) +#### Pruned transducer stateless 7 (zipformer) + +See + +[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe) + +**Note**: The modeling units are byte level BPEs + +The best results I have gotten are: + +Vocab size | Greedy search(dev & test) | Modified beam search(dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments +-- | -- | -- | -- | -- | -- +500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29 + +The training command: + +``` +export CUDA_VISIBLE_DEVICES="4,5,6,7" + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --max-duration 800 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --lr-epochs 6 \ + --master-port 12535 +``` + +The decoding command: + +``` +for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do + ./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 48 \ + --avg 29 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-sym-per-frame 1 \ + --ngram-lm-scale 0.25 \ + --ilme-scale 0.2 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 2000 \ + --decoding-method $m +done +``` + +The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + + #### Pruned transducer stateless 3 See @@ -15,6 +66,8 @@ It uses pruned RNN-T. |------------------------|------|------|---------------------------------------| | greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 | | modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 | | fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 | Training command is: @@ -73,6 +126,78 @@ for epoch in 29; do done ``` +We provide the option of shallow fusion with a RNN language model. The pre-trained language model is +available at . 
To decode with the language model, +please use the following command: + +```bash +# download pre-trained model +git lfs install +git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 + +aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/ + +pushd ${aishell_exp}/exp +ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt +popd + +# download RNN LM +git lfs install +git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm +rnnlm_dir=icefall-aishell-rnn-lm + +# RNNLM shallow fusion +for lm_scale in $(seq 0.26 0.02 0.34); do + python ./pruned_transducer_stateless3/decode.py \ + --epoch 99 \ + --avg 1 \ + --lang-dir ${aishell_exp}/data/lang_char \ + --exp-dir ${aishell_exp}/exp \ + --use-averaged-model False \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir ${rnnlm_dir}/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 2 \ + --lm-vocab-size 4336 +done + +# RNNLM Low-order density ratio (LODR) with a 2-gram + +cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt + +for lm_scale in 0.48; do + for LODR_scale in -0.28; do + python ./pruned_transducer_stateless3/decode.py \ + --epoch 99 \ + --avg 1 \ + --lang-dir ${aishell_exp}/data/lang_char \ + --exp-dir ${aishell_exp}/exp \ + --use-averaged-model False \ + --decoding-method modified_beam_search_LODR \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir ${rnnlm_dir}/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 2 \ + --lm-vocab-size 4336 \ + --tokens-ngram 2 \ + --backoff-id 4336 \ + --ngram-lm-scale $LODR_scale + done +done + +``` + Pretrained models, training logs, decoding logs, and decoding results are available at diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py index cb7205e51..ab1cbbae4 100644 --- a/egs/aishell/ASR/conformer_ctc/conformer.py +++ b/egs/aishell/ASR/conformer_ctc/conformer.py @@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * 
self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 751b7d5b5..74a7b5933 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -401,9 +401,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -431,9 +429,7 @@ def save_results( # we compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -441,9 +437,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -562,9 +556,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py index 42b8c29e7..1df3cfdc2 100644 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -157,9 +157,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 27776bc24..66d583396 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -210,10 +210,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -274,9 +273,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -371,9 +368,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/aishell/ASR/conformer_ctc/subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py index e3361d0c9..81fa234dd 100755 --- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py @@ -16,9 +16,8 @@ # limitations under the License. 
-from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index a228cc1fe..c2cbe6e3b 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -382,9 +382,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -520,9 +518,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -630,9 +626,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index f93914aaa..a3e50e385 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -149,9 +149,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -183,9 +181,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,23 +262,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -343,23 +333,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, 
padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -632,9 +616,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -836,9 +818,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -859,9 +839,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py index cb7205e51..ab1cbbae4 100644 --- a/egs/aishell/ASR/conformer_mmi/conformer.py +++ b/egs/aishell/ASR/conformer_mmi/conformer.py @@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index 4db367e36..20a855e7f 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -413,9 +413,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -443,9 +441,7 @@ def save_results( # we compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=enable_log @@ -453,9 +449,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" @@ -550,9 +544,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -581,9 +573,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py index 720ed6c22..398837a46 100644 --- a/egs/aishell/ASR/conformer_mmi/subsampling.py +++ b/egs/aishell/ASR/conformer_mmi/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
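Aside (not part of the patch): in the `Conv2dSubsampling` hunks above, the `(((idim - 1) // 2 - 1) // 2)` factor in the `nn.Linear` input size is just the number of frequency bins left after the two kernel-3, stride-2 convolutions. A self-contained check, with an illustrative `odim` and assuming nothing beyond standard PyTorch:

```python
import torch
import torch.nn as nn

idim, odim = 80, 256  # 80 mel bins as in these recipes; odim chosen only for illustration

# Same conv stack as in the Conv2dSubsampling hunk: two kernel-3, stride-2 convs.
conv = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
    nn.ReLU(),
    nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
    nn.ReLU(),
)

x = torch.randn(1, 1, 100, idim)  # (N, 1, T, idim)
y = conv(x)                       # (N, odim, T', idim')

# Each conv maps a length L to (L - 1) // 2, hence the nested expression:
expected_freq = ((idim - 1) // 2 - 1) // 2  # 80 -> 39 -> 19
assert y.shape[-1] == expected_freq == 19
print(y.shape)  # torch.Size([1, 256, 24, 19]); time is subsampled by roughly 4 as well
```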
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py index 685831d09..09cd6e60c 100755 --- a/egs/aishell/ASR/conformer_mmi/train.py +++ b/egs/aishell/ASR/conformer_mmi/train.py @@ -511,9 +511,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -625,9 +623,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py index f93914aaa..a3e50e385 100644 --- a/egs/aishell/ASR/conformer_mmi/transformer.py +++ b/egs/aishell/ASR/conformer_mmi/transformer.py @@ -149,9 +149,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -183,9 +181,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,23 +262,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -343,23 +333,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = 
generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -632,9 +616,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -836,9 +818,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -859,9 +839,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/aishell/ASR/local/compile_lg.py b/egs/aishell/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/aishell/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py index 42700a972..037971927 100755 --- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py @@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -116,9 +114,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index deab6c809..115ca1031 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -111,9 +109,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py index d9e47d17a..8cc0502c2 100755 --- a/egs/aishell/ASR/local/prepare_char.py +++ 
b/egs/aishell/ASR/local/prepare_char.py @@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`: - tokens.txt """ +import argparse import re from pathlib import Path from typing import Dict, List @@ -86,9 +87,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +141,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: @@ -193,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]: return tokens +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + return parser.parse_args() + + def main(): - lang_dir = Path("data/lang_char") + args = get_args() + lang_dir = Path(args.lang_dir) text_file = lang_dir / "text" word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py new file mode 100755 index 000000000..e7995680b --- /dev/null +++ b/egs/aishell/ASR/local/prepare_char_lm_training_data.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey +# Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a `tokens.txt` and a text file such as +./download/lm/aishell-transcript.txt +and outputs the LM training data to a supplied directory such +as data/lm_training_char. The format is as follows: +It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a +representation of a dict with the same format with librispeech receipe +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-char", + type=str, + help="""Lang dir of asr model, e.g. data/lang_char""", + ) + parser.add_argument( + "--lm-data", + type=str, + help="""Input LM training data as text, e.g. + download/lm/aishell-train-word.txt""", + ) + parser.add_argument( + "--lm-archive", + type=str, + help="""Path to output archive, e.g. 
data/lm_training_char/lm_data.pt; + look at the source of this script to see the format.""", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + if Path(args.lm_archive).exists(): + logging.warning(f"{args.lm_archive} exists - skipping") + return + + # make token_dict from tokens.txt in order to map characters to tokens. + token_dict = {} + token_file = args.lang_char + "/tokens.txt" + + with open(token_file, "r") as f: + for line in f.readlines(): + line_list = line.split() + token_dict[line_list[0]] = int(line_list[1]) + + # word2index is a dictionary from words to integer ids. No need to reserve + # space for epsilon, etc.; the words are just used as a convenient way to + # compress the sequences of tokens. + word2index = dict() + + word2token = [] # Will be a list-of-list-of-int, representing tokens. + sentences = [] # Will be a list-of-list-of-int, representing word-ids. + + if "aishell-lm" in args.lm_data: + num_lines_in_total = 120098.0 + step = 50000 + elif "valid" in args.lm_data: + num_lines_in_total = 14326.0 + step = 3000 + elif "test" in args.lm_data: + num_lines_in_total = 7176.0 + step = 3000 + else: + num_lines_in_total = None + step = None + + processed = 0 + + with open(args.lm_data) as f: + while True: + line = f.readline() + if line == "": + break + + if step and processed % step == 0: + logging.info( + f"Processed number of lines: {processed} " + f"({processed / num_lines_in_total * 100: .3f}%)" + ) + processed += 1 + + line_words = line.split() + for w in line_words: + if w not in word2index: + w_token = [] + for t in w: + if t in token_dict: + w_token.append(token_dict[t]) + else: + w_token.append(token_dict[""]) + word2index[w] = len(word2token) + word2token.append(w_token) + sentences.append([word2index[w] for w in line_words]) + + logging.info("Constructing ragged tensors") + words = k2.ragged.RaggedTensor(word2token) + sentences = k2.ragged.RaggedTensor(sentences) + + output = dict(words=words, sentences=sentences) + + num_sentences = sentences.dim0 + logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + if step and i % step == 0: + logging.info( + f"Processed number of lines: {i} ({i / num_sentences * 100: .3f}%)" + ) + + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + + torch.save(output, args.lm_archive) + logging.info(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aishell/ASR/local/prepare_lang.py +++ b/egs/aishell/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git 
a/egs/aishell/ASR/local/prepare_lang_bbpe.py b/egs/aishell/ASR/local/prepare_lang_bbpe.py new file mode 100755 index 000000000..ddd90622e --- /dev/null +++ b/egs/aishell/ASR/local/prepare_lang_bbpe.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bbpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.byte_utils import byte_encode +from icefall.utils import str2bool, tokenize_by_CJK_char + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. 
+ """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str], oov: str +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words] + words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int) + + # Now convert word piece IDs back to word piece strings. + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. 
+ """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bbpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", args.oov, "#0", "", ""] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/local/sort_lm_training_data.py b/egs/aishell/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..1d6ccbe33 --- /dev/null +++ b/egs/aishell/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aishell/ASR/local/test_prepare_lang.py +++ b/egs/aishell/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell/ASR/local/train_bbpe_model.py b/egs/aishell/ASR/local/train_bbpe_model.py new file mode 100755 index 000000000..d231d5d77 --- /dev/null +++ b/egs/aishell/ASR/local/train_bbpe_model.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import re +import shutil +import tempfile +from pathlib import Path + +import sentencepiece as spm +from icefall import byte_encode, tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def _convert_to_bchar(in_path: str, out_path: str): + with open(out_path, "w") as f: + for line in open(in_path, "r").readlines(): + f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n") + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. + + temp = tempfile.NamedTemporaryFile() + train_text = temp.name + + _convert_to_bchar(args.transcript, train_text) + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bbpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index eaeecfc4a..b763d72c1 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -1,10 +1,13 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 stage=-1 -stop_stage=10 +stop_stage=11 # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -32,6 +35,15 @@ dl_dir=$PWD/download . shared/parse_options.sh || exit 1 +# vocab size for sentence piece models. +# It will generate data/lang_bbpe_xxx, +# data/lang_bbpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 2000 + # 1000 + 500 +) + # All files generated by this script are saved in "data". # You can safely remove "data" and rerun this script to regenerate it. 
mkdir -p data @@ -44,20 +56,6 @@ log() { log "dl_dir: $dl_dir" -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "stage -1: Download LM" - # We assume that you have installed the git-lfs, if not, you could install it - # using: `sudo apt-get install git-lfs && git-lfs install` - git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1) - - if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then - git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm - pushd $dl_dir/lm - git lfs pull --include "3-gram.unpruned.arpa" - popd - fi -fi - if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "stage 0: Download data" @@ -131,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi lang_phone_dir=data/lang_phone -lang_char_dir=data/lang_char if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Prepare phone based lang" mkdir -p $lang_phone_dir @@ -180,39 +177,194 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi fi +lang_char_dir=data/lang_char if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Prepare char based lang" mkdir -p $lang_char_dir # We reuse words.txt from phone based lexicon # so that the two can share G.pt later. - cp $lang_phone_dir/words.txt $lang_char_dir + + # The transcripts in training set, generated in stage 5 + cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt | - cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text + cut -d " " -f 2- > $lang_char_dir/text + + (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \ + > $lang_char_dir/words.txt + + cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \ + | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt + + num_lines=$(< $lang_char_dir/words.txt wc -l) + (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \ + >> $lang_char_dir/words.txt if [ ! -f $lang_char_dir/L_disambig.pt ]; then - ./local/prepare_char.py + ./local/prepare_char.py --lang-dir $lang_char_dir fi fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare G" - # We assume you have install kaldilm, if not, please install - # it using: pip install kaldilm + log "Stage 7: Prepare Byte BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + mkdir -p $lang_dir + + cp $lang_char_dir/words.txt $lang_dir + cp $lang_char_dir/text $lang_dir + + if [ ! -f $lang_dir/bbpe.model ]; then + ./local/train_bbpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bbpe.py --lang-dir $lang_dir + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then + + # Train LM on transcripts + if [ ! -f data/lm/3-gram.unpruned.arpa ]; then + python3 ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text $lang_char_dir/transcript_words.txt \ + -lm data/lm/3-gram.unpruned.arpa + fi + + # We assume you have installed kaldilm; if not, please install + # it using: pip install kaldilm + if [ !
-f data/lm/G_3_gram_char.fst.txt ]; then # It is used in building HLG python3 -m kaldilm \ --read-symbol-table="$lang_phone_dir/words.txt" \ --disambig-symbol='#0' \ --max-order=3 \ - $dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt + + python3 -m kaldilm \ + --read-symbol-table="$lang_char_dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt fi fi -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compile HLG" - ./local/compile_hlg.py --lang-dir $lang_phone_dir - ./local/compile_hlg.py --lang-dir $lang_char_dir +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile LG & HLG" + ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone + ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char + done + + ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone + ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char + done +fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Generate LM training data" + + log "Processing char based data" + out_dir=data/lm_training_char + mkdir -p $out_dir $dl_dir/lm + + if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then + cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt + fi + + # training words + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/aishell-train-word.txt \ + --lm-archive $out_dir/lm_data.pt + + # valid words + if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then + aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt + aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid + find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid + awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text | + cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/aishell-valid-word.txt \ + --lm-archive $out_dir/lm_data_valid.pt + + # test words + if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then + aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt + aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid + find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid + awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text | + cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/aishell-test-word.txt \ + --lm-archive $out_dir/lm_data_test.pt +fi + + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Sort LM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of tokens + # in a sentence. 
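+  # (Sorting groups sentences of similar length together, so the LM dataloader can build batches with little padding.)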
+ + out_dir=data/lm_training_char + mkdir -p $out_dir + ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/ + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data_valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data_test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Train RNN LM model" + # define out_dir here as well so that this stage can be run on its own + out_dir=data/lm_training_char + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 1 \ + --num-epochs 20 \ + --use-fp16 0 \ + --embedding-dim 512 \ + --hidden-dim 512 \ + --num-layers 2 \ + --batch-size 400 \ + --exp-dir rnnlm_char/exp \ + --lm-data $out_dir/sorted_lm_data.pt \ + --lm-data-valid $out_dir/sorted_lm_data-valid.pt \ + --vocab-size 4336 \ + --master-port 12345 fi diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index a12934d55..fb6c7c481 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -188,8 +184,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder.
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +244,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -263,10 +256,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -387,9 +377,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -400,24 +388,18 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
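    # (hyps and refs are split into single characters below, so write_error_stats effectively reports a character error rate)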
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -427,10 +409,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -473,9 +452,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -504,8 +481,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py index feababdd2..2ce5cfe69 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py @@ -50,11 +50,7 @@ from pathlib import Path import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -157,8 +152,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -191,9 +185,7 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = ( - params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - ) + filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -201,17 +193,14 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir - / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py index 3c38e5db7..82c10f129 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py @@ -165,8 +165,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -196,10 +195,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -256,13 +254,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -310,9 +304,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -329,9 +321,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index 97d892754..d08908238 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -49,7 +49,6 @@ import optim import torch import torch.multiprocessing as mp import torch.nn as nn - from asr_datamodule import AishellAsrDataModule from conformer import Conformer from decoder import Decoder @@ -75,9 +74,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -203,8 +200,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -227,8 +223,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -251,8 +246,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -561,11 +555,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -593,23 +583,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
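        # (the scale is 0.0 while warmup < 1.0, 0.1 for 1.0 <= warmup < 2.0, and 1.0 afterwards)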
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -725,9 +708,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -891,7 +872,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1029,9 +1010,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index d159e420b..27c64efaa 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -54,6 +54,40 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) modified beam search (with LM shallow fusion) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(6) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.48 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.28 \ """ @@ -74,9 +108,12 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -202,8 +239,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -213,6 +249,60 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) return parser @@ -224,6 +314,9 @@ def decode_one_batch( token_table: k2.SymbolTable, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -263,9 +356,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -277,10 +368,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -293,6 +381,24 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) else: hyp_tokens = [] batch_size = encoder_out.size(0) @@ -340,6 +446,9 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
@@ -385,6 +494,9 @@ def decode_dataset( token_table=token_table, decoding_graph=decoding_graph, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -401,9 +513,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -414,24 +524,18 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -441,10 +545,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -462,6 +563,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -475,6 +577,8 @@ def main(): "beam_search", "fast_beam_search", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -488,9 +592,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -498,6 +600,19 @@ def main(): if params.use_averaged_model: params.suffix += "-use-averaged-model" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") 
logging.info("Decoding started") @@ -518,9 +633,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -551,9 +666,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -607,6 +722,35 @@ def main(): else: decoding_graph = None + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + lm_filename, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -629,6 +773,9 @@ def main(): model=model, token_table=lexicon.token_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 deleted file mode 120000 index bcd4abc2f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 +++ /dev/null @@ -1 +0,0 @@ -/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 566902a85..723414167 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -48,6 +48,7 @@ import logging from pathlib import Path import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -132,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -166,9 +166,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -195,9 +195,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -245,6 +245,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not @@ -252,9 +253,7 @@ def main(): model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) - filename = ( - params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - ) + filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") else: @@ -262,17 +261,14 @@ def main(): # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = ( - params.exp_dir - / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" ) torch.save({"model": model.state_dict()}, str(filename)) logging.info(f"Saved to {filename}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py new file mode 120000 index 000000000..557e18aa1 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/lstmp.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index e150e8230..a4dda0d6d 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_datatang = decoder_datatang self.joiner_datatang = joiner_datatang - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_datatang is not None: @@ -179,9 +177,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py 
b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py index 04a0a882a..ead393e6e 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -165,8 +165,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -196,10 +195,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,13 +255,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -311,9 +305,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -330,9 +322,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index feaef5cf6..62e67530d 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -96,9 +96,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -224,8 +222,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -248,8 +245,7 @@ def get_parser(): "--context-size", type=int, default=1, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -272,8 +268,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -635,11 +630,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -670,23 +661,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -824,9 +808,7 @@ def train_one_epoch( ) # summary stats if datatang_train_dl is not None: - tot_loss = ( - tot_loss * (1 - 1 / params.reset_interval) - ) + loss_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info if aishell: aishell_tot_loss = ( @@ -847,9 +829,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -892,9 +872,7 @@ def train_one_epoch( cur_lr = scheduler.get_last_lr()[0] if datatang_train_dl is not None: datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], " - tot_loss_str = ( - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - ) + tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, " else: tot_loss_str = "" datatang_str = "" @@ -1067,7 +1045,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1076,9 +1054,7 @@ def run(rank, world_size, args): train_cuts = filter_short_and_long_utterances(train_cuts) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -1093,9 +1069,7 @@ def run(rank, world_size, args): if params.datatang_prob > 0: datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) train_datatang_cuts = datatang.train_cuts() - train_datatang_cuts = filter_short_and_long_utterances( - train_datatang_cuts - ) + train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts) train_datatang_cuts = 
train_datatang_cuts.repeat(times=None) datatang_train_dl = asr_datamodule.train_dataloaders( train_datatang_cuts, @@ -1249,9 +1223,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py new file mode 100755 index 000000000..fcb0ebc4e --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py @@ -0,0 +1,819 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_bbpe/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import ( + LmScorer, + NgramLm, + byte_encode, + smart_byte_decode, + tokenize_by_CJK_char, +) +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the byte BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.25, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + subtract_ilme=True, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + ref_texts = [] + for tx in supervisions["text"]: + ref_texts.append(byte_encode(tokenize_by_CJK_char(tx))) + + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(ref_texts), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = 
f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + key += f"_ilme_scale_{params.ilme_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + params.suffix += f"-ilme-scale-{params.ilme_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + 
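+            # average_checkpoints() returns a state_dict whose parameters are
+            # the element-wise average of the listed checkpoint files.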
model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
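+    # With return_cuts=True, batch["supervisions"] also carries the lhotse Cut
+    # objects, which is how decode_dataset() obtains a stable utterance id
+    # (cut.id) for every hypothesis.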
+ args.return_cuts = True + aishell = AishellAsrDataModule(args) + + test_cuts = aishell.test_cuts() + dev_cuts = aishell.valid_cuts() + + test_dl = aishell.test_dataloaders(test_cuts) + dev_dl = aishell.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py new file mode 100755 index 000000000..4e82b45d3 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. 
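+
+A minimal loading sketch (the checkpoint path below is a placeholder, and the
+model options in `params` must match the ones used for training, e.g. filled in
+the same way as in decode.py):
+
+    from train import get_params, get_transducer_model
+    from icefall.checkpoint import load_checkpoint
+
+    params = get_params()
+    # also set params.vocab_size, params.blank_id and the model arguments
+    # (encoder dims, number of layers, etc.) before building the model
+    model = get_transducer_model(params)
+    load_checkpoint("pruned_transducer_stateless7_bbpe/exp/pretrained.pt", model)
+    model.eval()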
+ +To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/aishell/ASR + ./pruned_transducer_stateless7_bbpe/decode.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_500/bbpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe + # You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_bbpe/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
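+        # (Users of the scripted model call model.encoder(), model.decoder()
+        # and model.joiner() directly, as in ./jit_pretrained.py, rather than
+        # model.forward().)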
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py new file mode 100755 index 000000000..0c43bf74b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 49 \ + --avg 28 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_bbpe/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. 
+ Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in 
waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = smart_byte_decode(sp.decode(hyp)) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py new file mode 100755 index 000000000..ea5bda4db --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_bbpe/export.py \ + --exp-dir ./pruned_transducer_stateless7_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --epoch 48 \ + --avg 29 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./pruned_transducer_stateless7_bbpe/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \ + --bpe-model ./data/lang_bbpe_500/bbpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_bbpe/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import smart_byte_decode +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + 
beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py new file mode 120000 index 000000000..7ceac5d10 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py new file mode 100755 index 000000000..499badb14 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -0,0 +1,1261 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --max-duration 400 + +# For mix precision training: + +./pruned_transducer_stateless7_bbpe/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_bbpe/exp \ + --max-duration 800 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, + tokenize_by_CJK_char, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma 
separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_bbpe/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the Byte BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 2000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 12.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_cuts = train_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = aishell.valid_cuts() + valid_cuts = valid_cuts.map(tokenize_and_encode_text) + + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index d24ba6bb7..efb32336a 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -21,7 +21,7 @@ import inspect import logging from functools import lru_cache from pathlib import Path -from typing import List +from typing import Any, Dict, List, Optional from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse.dataset import ( @@ -181,19 +181,24 @@ class AishellAsrDataModule: "with training dataset. ", ) - def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: + def train_dataloaders( + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -215,9 +220,7 @@ class AishellAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. 
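A minimal sketch of the version check described in the comment above, selecting num_frame_masks by inspecting SpecAugment's signature rather than comparing Lhotse version strings; the concrete values 10 and 2 are illustrative assumptions, not necessarily what this recipe uses:

import inspect
import logging

from lhotse.dataset import SpecAugment

# The default of `num_frame_masks` changed across Lhotse releases, so read
# it from the constructor signature instead of pinning a version number.
default_num_frame_masks = (
    inspect.signature(SpecAugment.__init__).parameters["num_frame_masks"].default
)

# Assumed policy: ask for more masks explicitly when the installed Lhotse
# still defaults to a single frame mask.
num_frame_masks = 10 if default_num_frame_masks == 1 else 2
logging.info(f"Using num_frame_masks={num_frame_masks}")

Reading the default from the signature keeps the datamodule working across Lhotse releases without parsing lhotse.__version__.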
@@ -260,9 +263,7 @@ class AishellAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -285,6 +286,10 @@ class AishellAsrDataModule: ) logging.info("About to create train dataloader") + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + train_dl = DataLoader( train, sampler=train_sampler, @@ -308,9 +313,7 @@ class AishellAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -335,7 +338,7 @@ class AishellAsrDataModule: return valid_dl def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") + logging.info("About to create test dataset") test = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats @@ -366,13 +369,9 @@ class AishellAsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get dev cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> List[CutSet]: logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_test.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index 66b734fc4..824ca2a92 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -265,9 +265,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -289,9 +287,7 @@ def save_results( # We compute CER for aishell dataset. 
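+        # Splitting each transcript into a list of single characters makes
+        # write_error_stats() report a character error rate (CER) rather
+        # than a word error rate.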
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) test_set_wers[key] = wer @@ -335,9 +331,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -362,9 +356,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") model.to(device) model.eval() @@ -392,9 +384,7 @@ def main(): lexicon=lexicon, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py index 5e04c11b4..1731e1ebe 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py @@ -66,10 +66,7 @@ class TdnnLstm(nn.Module): nn.BatchNorm1d(num_features=500, affine=False), ) self.lstms = nn.ModuleList( - [ - nn.LSTM(input_size=500, hidden_size=500, num_layers=1) - for _ in range(5) - ] + [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 9bd810809..7e7213501 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -53,9 +53,7 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") parser.add_argument( "--method", @@ -112,10 +110,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -173,9 +170,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = features.permute(0, 2, 1) # now features is [N, C, T] with torch.no_grad(): @@ -219,9 +214,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py index 7619b0551..e574cf89b 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py @@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py index 9ed9b2ad1..de0a8d0f5 100644 --- a/egs/aishell/ASR/transducer_stateless/beam_search.py +++ b/egs/aishell/ASR/transducer_stateless/beam_search.py @@ -47,9 +47,9 @@ def greedy_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -81,9 +81,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -157,9 +157,7 @@ class HypothesisList(object): """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -246,9 +244,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py index 64114253d..78424aea2 100644 --- a/egs/aishell/ASR/transducer_stateless/conformer.py +++ b/egs/aishell/ASR/transducer_stateless/conformer.py @@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = 
nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -701,31 +689,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -764,9 +743,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -778,9 +755,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -814,13 +789,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -843,9 +814,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index 780b0c4bb..d23f4f883 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -99,8 +99,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -227,9 +226,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] batch_size = encoder_out.size(0) @@ -248,9 +245,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") hyps.append([lexicon.token_table[i] for i in hyp]) if params.decoding_method == "greedy_search": @@ -319,9 +314,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -332,23 +325,17 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -358,10 +345,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -430,9 +414,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index c2c6552a9..70e9e6c96 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -86,9 +86,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/aishell/ASR/transducer_stateless/export.py 
b/egs/aishell/ASR/transducer_stateless/export.py index 4c6519b96..01de5d772 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -110,8 +110,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -243,9 +242,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 994305fc1..591bbe44f 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -103,9 +103,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index db89c4d67..40f430e13 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -117,8 +117,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -211,10 +210,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -273,9 +271,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -319,9 +315,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index d54157709..62ffff473 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -126,8 +126,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -389,9 +388,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -504,9 +501,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -625,9 +620,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py index e851dcc32..b3ff153c1 100644 --- a/egs/aishell/ASR/transducer_stateless/transformer.py +++ b/egs/aishell/ASR/transducer_stateless/transformer.py @@ -250,9 +250,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index 838e53658..5d49d7338 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -29,10 +29,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -162,9 +159,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -173,9 +168,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. 
@@ -252,9 +245,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index ea3f94fd8..d164b6890 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -170,8 +170,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -227,9 +226,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -241,10 +238,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -365,9 +359,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -378,24 +370,18 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -405,10 +391,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -448,9 +431,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index 3bd2ceb11..c1081c32b 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -109,8 +109,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -241,9 +240,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index a95a4bc52..5d8ca2e11 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -165,8 +165,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -194,10 +193,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,13 +252,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,9 +302,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -327,9 +319,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py index 225d0d709..8fb7d1e49 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py @@ -149,8 +149,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -168,8 +167,7 @@ def get_parser(): "--datatang-prob", type=float, default=0.2, - help="The probability to select a batch from the " - "aidatatang_200zh dataset", + help="The probability to select a batch from the aidatatang_200zh dataset", ) return parser @@ -449,9 +447,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -605,9 +601,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) aishell_tot_loss.write_summary( tb_writer, "train/aishell_tot_", params.batch_idx_train ) @@ -735,9 +729,7 @@ def run(rank, world_size, args): train_datatang_cuts = train_datatang_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None @@ -776,9 +768,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 65fcda873..0a7d87fe8 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -171,8 +171,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -231,9 +230,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": hyp_tokens = fast_beam_search_one_best( @@ -245,10 +242,7 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -369,9 +363,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -382,24 +374,18 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
results_char = [] for res in results: - results_char.append( - (res[0], list("".join(res[1])), list("".join(res[2]))) - ) + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -409,10 +395,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -452,9 +435,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index 11335a834..3e14ad69c 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -109,8 +109,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -241,9 +240,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 262e822c2..9e4459247 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -165,8 +165,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -194,10 +193,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,13 +252,9 @@ def main(): feature_lens = [f.size(0) for f in features] feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens) num_waves = encoder_out.size(0) hyp_list = [] @@ -308,9 +302,7 @@ def main(): beam=params.beam_size, ) else: - raise ValueError( - f"Unsupported decoding method: {params.method}" - ) + raise ValueError(f"Unsupported decoding method: {params.method}") hyp_list.append(hyp) hyps = [] @@ -327,9 +319,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py index d3ffccafa..5f116f2bd 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified/train.py @@ -142,8 +142,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -414,9 +413,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -529,9 +526,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -657,9 +652,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py old mode 100755 new mode 100644 diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index d8d3622bd..ec0c584ca 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -111,9 +109,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 06810bfdd..3e8e840ab 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=30 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py old mode 100755 new mode 100644 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py old mode 100755 new mode 100644 index b7a21f579..0f383a244 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -216,13 +216,9 @@ class AiShell2AsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -244,9 +240,7 @@ class AiShell2AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. 
# In different Lhotse's versions, the default of num_frame_masks is # different. @@ -290,9 +284,7 @@ class AiShell2AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -348,9 +340,7 @@ class AiShell2AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -406,9 +396,7 @@ class AiShell2AsrDataModule: @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz") - return load_manifest_lazy( - self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz" - ) + return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz") @lru_cache() def test_cuts(self) -> CutSet: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 915737f4a..9e44b4e34 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -269,8 +269,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -348,9 +347,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -409,10 +406,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -538,9 +532,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -551,18 +543,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -572,10 +560,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -625,9 +610,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -661,9 +644,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -690,9 +673,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -749,9 +732,7 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index bc7bd71cb..8a5be94d0 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -167,9 +166,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -196,9 +195,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -266,9 +265,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index 09de1bece..bc3ae7abf 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -159,8 +159,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -191,10 +190,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -254,15 +252,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -334,9 +328,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 838a0497f..74bf68ccb 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -92,9 +92,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -220,8 +218,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -244,8 +241,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -268,8 +264,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -603,11 +598,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -636,23 +627,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -771,9 +755,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise if params.print_diagnostics and batch_idx == 5: @@ -829,9 +811,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -939,7 +919,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1104,9 +1084,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." 
) - display_and_save_batch( - batch, params=params, graph_compiler=graph_compiler - ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) raise diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 3f50d9e3e..400c406f0 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -120,9 +118,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/aishell4/ASR/local/prepare_char.py +++ b/egs/aishell4/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/aishell4/ASR/local/prepare_lang.py +++ b/egs/aishell4/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/aishell4/ASR/local/test_prepare_lang.py +++ b/egs/aishell4/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py index 71be2a613..85047c367 100755 --- a/egs/aishell4/ASR/local/text2token.py +++ b/egs/aishell4/ASR/local/text2token.py @@ -56,9 +56,7 @@ def get_parser(): parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +64,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +104,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +130,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index c351e3964..cb2b73a3e 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail stage=-1 diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index 7aa53ddda..d980a857f 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -222,17 +222,13 @@ class Aishell4AsrDataModule: The state dict for the training 
sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -254,9 +250,7 @@ class Aishell4AsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -300,9 +294,7 @@ class Aishell4AsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -359,9 +351,7 @@ class Aishell4AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 14e44c7d9..068e2749a 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -201,8 +201,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -260,9 +259,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -277,10 +274,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -401,9 +395,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -414,18 +406,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -435,10 +423,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -480,9 +465,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -510,9 +493,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -543,9 +526,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git 
a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index 993341131..bf9856c60 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -169,9 +168,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -202,9 +201,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -276,9 +275,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index 1fa893637..ee898c303 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -204,10 +203,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -266,15 +264,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -306,10 +300,7 @@ def main(): for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,9 +341,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 0a48b9059..d7c69f226 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -85,9 +85,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -213,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -237,8 +234,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -261,8 +257,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -599,11 +594,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -633,22 +624,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -827,9 +811,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -937,7 +919,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index af926aa53..96115a230 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = min(cur_num_jobs, len(cut_set)) @@ -121,9 +119,7 @@ def get_args(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py index d9e47d17a..6b440dfb3 100755 --- a/egs/alimeeting/ASR/local/prepare_char.py +++ b/egs/alimeeting/ASR/local/prepare_char.py @@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [ - token2id[i] if i in token2id else token2id[""] for i in pieces - ] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon( - token_sym_table: Dict[str, int], words: List[str] -) -> Lexicon: +def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. 
Args: diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py index e5ae89ec4..c8cf9b881 100755 --- a/egs/alimeeting/ASR/local/prepare_lang.py +++ b/egs/alimeeting/ASR/local/prepare_lang.py @@ -317,9 +317,7 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", type=str, help="The lang dir, data/lang_phone" - ) + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") return parser.parse_args() diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/alimeeting/ASR/local/test_prepare_lang.py +++ b/egs/alimeeting/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py index 7c1019aa8..27b904fc8 100644 --- a/egs/alimeeting/ASR/local/text2segments.py +++ b/egs/alimeeting/ASR/local/text2segments.py @@ -30,8 +30,8 @@ with word segmenting: import argparse -import paddle import jieba +import paddle from tqdm import tqdm paddle.enable_static() diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py index 71be2a613..85047c367 100755 --- a/egs/alimeeting/ASR/local/text2token.py +++ b/egs/alimeeting/ASR/local/text2token.py @@ -56,9 +56,7 @@ def get_parser(): parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument( - "--space", default="", type=str, help="space symbol" - ) + parser.add_argument("--space", default="", type=str, help="space symbol") parser.add_argument( "--non-lang-syms", "-l", @@ -66,9 +64,7 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument( - "text", type=str, default=False, nargs="?", help="input text" - ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") parser.add_argument( "--trans_type", "-t", @@ -108,8 +104,7 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id - for txt in text + token_table[txt] if txt in token_table else oov_id for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -135,9 +130,7 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")( - sys.stdin if is_python2 else sys.stdin.buffer - ) + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 17224bb68..604cc92c6 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail stage=-1 
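The text2token.py hunks above only reflow the lazy_pinyin-based token lookup; the logic is unchanged. For readers unfamiliar with that step, here is a minimal, self-contained sketch of what the token-to-id mapping does. The toy token table and OOV id are assumptions for illustration only; the real script loads the token table from the prepared lang directory.

```python
# Minimal sketch of the lazy_pinyin token-to-id mapping used in text2token.py.
# Requires: pip install pypinyin
from pypinyin import lazy_pinyin

token_table = {"ni": 10, "hao": 11}  # hypothetical table: pinyin token -> id
oov_id = 1                           # hypothetical id for out-of-vocabulary tokens

chars_list = list("你好吗")           # split the utterance into characters
text = lazy_pinyin(chars_list)       # -> ['ni', 'hao', 'ma']

# Map each pinyin token to its id, falling back to oov_id for unseen tokens
# (this mirrors the list comprehension shown in the diff above).
sub_ids = [token_table[t] if t in token_table else oov_id for t in text]
print(sub_ids)                       # -> [10, 11, 1]
```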
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index bf6faad7a..a9a4675a9 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -205,17 +205,13 @@ class AlimeetingAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -237,9 +233,7 @@ class AlimeetingAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -282,9 +276,7 @@ class AlimeetingAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -341,9 +333,7 @@ class AlimeetingAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index 6358fe970..6c170c392 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -70,11 +70,7 @@ from beam_search import ( from lhotse.cut import Cut from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -193,8 +189,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -249,9 +244,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -266,10 +259,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -390,9 +380,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -403,18 +391,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -424,10 +408,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -563,8 +544,7 @@ def main(): ) dev_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -574,8 +554,7 @@ def main(): ) test_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) + str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar"))) ] cuts_test_webdataset = CutSet.from_webdataset( test_shards, @@ -588,9 +567,7 @@ def main(): return 1.0 <= c.duration cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt) - cuts_test_webdataset = cuts_test_webdataset.filter( - remove_short_and_long_utt - ) + cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt) dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset) test_dl = alimeeting.test_dataloaders(cuts_test_webdataset) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 8beec1b8a..8e5cc6075 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ 
b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -103,8 +103,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -173,9 +172,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py index 93b1e1f57..f5a0dd8c8 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py @@ -162,8 +162,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -193,10 +192,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -257,9 +255,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -284,10 +280,7 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -339,9 +332,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 81a0ede7f..e57b5c859 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -81,9 +81,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] os.environ["CUDA_LAUNCH_BLOCKING"] = "1" @@ -187,8 +185,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -211,8 +208,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -542,22 +538,15 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -711,9 +700,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -813,7 +800,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR_v2/README.md b/egs/alimeeting/ASR_v2/README.md new file mode 100644 index 000000000..f70327501 --- /dev/null +++ b/egs/alimeeting/ASR_v2/README.md @@ -0,0 +1,38 @@ + +# Introduction + +This recipe trains multi-domain ASR models for AliMeeting. By multi-domain, we mean that +we train a single model on close-talk and far-field conditions. This recipe optionally +uses [GSS]-based enhancement for far-field array microphone. +We pool data in the following 4 ways and train a single model on the pooled data: + +(i) individual headset microphone (IHM) +(ii) IHM with simulated reverb +(iii) Single distant microphone (SDM) +(iv) GSS-enhanced array microphones + +This is different from `alimeeting/ASR` since that recipe trains a model only on the +far-field audio. Additionally, we use text normalization here similar to the original +M2MeT challenge, so the results should be more comparable to those from Table 4 of +the [paper](https://arxiv.org/abs/2110.07393). + +The following additional packages need to be installed to run this recipe: +* `pip install jieba` +* `pip install paddlepaddle` +* `pip install git+https://github.com/desh2608/gss.git` + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +## Performance Record + +### pruned_transducer_stateless7 + +The following are decoded using `modified_beam_search`: + +| Evaluation set | eval WER | test WER | +|--------------------------|------------|---------| +| IHM | 9.58 | 11.53 | +| SDM | 23.37 | 25.85 | +| MDM (GSS-enhanced) | 11.82 | 14.22 | + +See [RESULTS](/egs/alimeeting/ASR_v2/RESULTS.md) for details. 
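To make the data pooling described in this README concrete, below is a rough sketch (not the recipe's actual training code) of how the four conditions can be combined into a single Lhotse CutSet. The manifest names follow local/compute_fbank_alimeeting.py further down in this diff; speed perturbation, MUSAN mixing, and feature extraction are handled by the recipe itself.

```python
# Rough sketch: pool IHM, reverberated IHM, SDM, and GSS-enhanced cuts into one
# training CutSet. Paths assume the manifests produced by
# local/compute_fbank_alimeeting.py.
from lhotse import CutSet

cuts_ihm = CutSet.from_file("data/manifests/cuts_train_ihm.jsonl.gz")
cuts_ihm_rvb = CutSet.from_file("data/manifests/cuts_train_ihm_rvb.jsonl.gz")
cuts_sdm = CutSet.from_file("data/manifests/cuts_train_sdm.jsonl.gz")
cuts_gss = CutSet.from_file("data/manifests/cuts_train_gss.jsonl.gz")

# Pooling is simple concatenation of the cut sets; a single model is then
# trained on the combined data.
train_cuts = cuts_ihm + cuts_ihm_rvb + cuts_sdm + cuts_gss
train_cuts.describe()  # print duration statistics of the pooled data
```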
diff --git a/egs/alimeeting/ASR_v2/RESULTS.md b/egs/alimeeting/ASR_v2/RESULTS.md new file mode 100644 index 000000000..15b24250d --- /dev/null +++ b/egs/alimeeting/ASR_v2/RESULTS.md @@ -0,0 +1,90 @@ +## Results (CER) + +#### 2022-12-09 + +#### Zipformer (pruned_transducer_stateless7) + +Zipformer encoder + non-current decoder. The decoder +contains only an embedding layer, a Conv1d (with kernel size 2) and a linear +layer (to transform tensor dim). + +All the results below are using a single model that is trained by combining the following +data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise +augmentation are applied on top of the pooled data. + +**WERs for IHM:** + +| | eval | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 10.13 | 12.21 | --epoch 15 --avg 8 --max-duration 500 | +| modified beam search | 9.58 | 11.53 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 9.92 | 12.07 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +**WERs for SDM:** + +| | eval | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 23.70 | 26.41 | --epoch 15 --avg 8 --max-duration 500 | +| modified beam search | 23.37 | 25.85 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 23.60 | 26.38 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +**WERs for GSS-enhanced MDM:** + +| | eval | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 12.24 | 14.99 | --epoch 15 --avg 8 --max-duration 500 | +| modified beam search | 11.82 | 14.22 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 12.30 | 14.98 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 15 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 300 \ + --max-cuts 100 \ + --prune-range 5 \ + --lr-factor 5 \ + --lm-scale 0.25 \ + --use-fp16 True +``` + +The decoding command is: +``` +# greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 15 \ + --avg 8 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method greedy_search + +# modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 15 \ + --avg 8 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +# fast beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 15 \ + --avg 8 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +``` + +Pretrained model is available at + +The tensorboard training log can be found at + diff --git a/egs/alimeeting/ASR_v2/local/__init__.py b/egs/alimeeting/ASR_v2/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py new file mode 100755 index 000000000..c6aa2ab36 --- /dev/null +++ 
b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the AliMeeting dataset. +For the training data, we prepare IHM, reverberated IHM, SDM, and GSS-enhanced +audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced +parts (which are the 3 evaluation settings). +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" +import logging +from pathlib import Path + +import torch +import torch.multiprocessing +from lhotse import CutSet, LilcomChunkyWriter +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def compute_fbank_ami(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + logging.info("Reading manifests") + manifests_ihm = read_manifests_if_cached( + dataset_parts=["train", "eval", "test"], + output_dir=src_dir, + prefix="alimeeting-ihm", + suffix="jsonl.gz", + ) + manifests_sdm = read_manifests_if_cached( + dataset_parts=["train", "eval", "test"], + output_dir=src_dir, + prefix="alimeeting-sdm", + suffix="jsonl.gz", + ) + # For GSS we already have cuts so we read them directly. 
+ manifests_gss = read_manifests_if_cached( + dataset_parts=["train", "eval", "test"], + output_dir=src_dir, + prefix="alimeeting-gss", + suffix="jsonl.gz", + ) + + def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None: + cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) + _ = cuts.compute_and_store_features_batch( + extractor=extractor, + storage_path=storage_path, + manifest_path=manifest_path, + batch_duration=5000, + num_workers=8, + storage_type=LilcomChunkyWriter, + ) + + logging.info( + "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)" + ) + + logging.info("Processing train split IHM") + cuts_ihm = ( + CutSet.from_manifests(**manifests_ihm["train"]) + .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) + .modify_ids(lambda x: x + "-ihm") + ) + _extract_feats( + cuts_ihm, + output_dir / "feats_train_ihm", + src_dir / "cuts_train_ihm.jsonl.gz", + ) + + logging.info("Processing train split IHM + reverberated IHM") + cuts_ihm_rvb = cuts_ihm.reverb_rir() + _extract_feats( + cuts_ihm_rvb, + output_dir / "feats_train_ihm_rvb", + src_dir / "cuts_train_ihm_rvb.jsonl.gz", + ) + + logging.info("Processing train split SDM") + cuts_sdm = ( + CutSet.from_manifests(**manifests_sdm["train"]) + .trim_to_supervisions(keep_overlapping=False) + .modify_ids(lambda x: x + "-sdm") + ) + _extract_feats( + cuts_sdm, + output_dir / "feats_train_sdm", + src_dir / "cuts_train_sdm.jsonl.gz", + ) + + logging.info("Processing train split GSS") + cuts_gss = ( + CutSet.from_manifests(**manifests_gss["train"]) + .trim_to_supervisions(keep_overlapping=False) + .modify_ids(lambda x: x + "-gss") + ) + _extract_feats( + cuts_gss, + output_dir / "feats_train_gss", + src_dir / "cuts_train_gss.jsonl.gz", + ) + + logging.info("Preparing test cuts: IHM, SDM, GSS (optional)") + for split in ["eval", "test"]: + logging.info(f"Processing {split} IHM") + cuts_ihm = ( + CutSet.from_manifests(**manifests_ihm[split]) + .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_ihm", + manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz", + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + logging.info(f"Processing {split} SDM") + cuts_sdm = ( + CutSet.from_manifests(**manifests_sdm[split]) + .trim_to_supervisions(keep_overlapping=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_sdm", + manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz", + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + logging.info(f"Processing {split} GSS") + cuts_gss = ( + CutSet.from_manifests(**manifests_gss[split]) + .trim_to_supervisions(keep_overlapping=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_gss", + manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz", + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_ami() diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py @@ -0,0 
+1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
new file mode 100644
index 000000000..f1512efa5
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Data preparation for the AliMeeting GSS-enhanced dataset.
+
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from lhotse import Recording, RecordingSet, SupervisionSet
+from lhotse.qa import fix_manifests
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fastcopy
+from tqdm import tqdm
+
+logging.basicConfig(
+    format="%(asctime)s %(levelname)-8s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="AliMeeting GSS-enhanced dataset preparation.")
+    parser.add_argument(
+        "manifests_dir",
+        type=Path,
+        help="Path to directory containing AliMeeting manifests.",
+    )
+    parser.add_argument(
+        "enhanced_dir",
+        type=Path,
+        help="Path to enhanced data directory.",
+    )
+    parser.add_argument(
+        "--num-jobs",
+        "-j",
+        type=int,
+        default=1,
+        help="Number of parallel jobs to run.",
+    )
+    parser.add_argument(
+        "--min-segment-duration",
+        "-d",
+        type=float,
+        default=0.0,
+        help="Minimum duration of a segment in seconds.",
+    )
+    return parser.parse_args()
+
+
+def find_recording_and_create_new_supervision(enhanced_dir, supervision):
+    """
+    Given a supervision (corresponding to an original AliMeeting recording), this function
+    finds the enhanced recording corresponding to the supervision, and returns this recording
+    and a new supervision whose start and end times are adjusted to match the enhanced recording.
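+    The enhanced segments are expected on disk as
+    <enhanced_dir>/<recording_id>/<recording_id>-<speaker>-<start>_<end>.flac,
+    where <start> and <end> are the original segment boundaries in centiseconds
+    (matching the ``save_path`` constructed below).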
+ """ + file_name = Path( + f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac" + ) + save_path = enhanced_dir / f"{supervision.recording_id}" / file_name + if save_path.exists(): + recording = Recording.from_file(save_path) + if recording.duration == 0: + logging.warning(f"Skipping {save_path} which has duration 0 seconds.") + return None + + # Old supervision is wrt to the original recording, we create new supervision + # wrt to the enhanced segment + new_supervision = fastcopy( + supervision, + recording_id=recording.id, + start=0, + duration=recording.duration, + ) + return recording, new_supervision + else: + logging.warning(f"{save_path} does not exist.") + return None + + +def main(args): + # Get arguments + manifests_dir = args.manifests_dir + enhanced_dir = args.enhanced_dir + + # Load manifests from cache if they exist (saves time) + manifests = read_manifests_if_cached( + dataset_parts=["train", "eval", "test"], + output_dir=manifests_dir, + prefix="alimeeting-sdm", + suffix="jsonl.gz", + ) + if not manifests: + raise ValueError( + "AliMeeting SDM manifests not found in {}".format(manifests_dir) + ) + + with ThreadPoolExecutor(args.num_jobs) as ex: + for part in ["train", "eval", "test"]: + logging.info(f"Processing {part}...") + supervisions_orig = manifests[part]["supervisions"].filter( + lambda s: s.duration >= args.min_segment_duration + ) + futures = [] + + for supervision in tqdm( + supervisions_orig, + desc="Distributing tasks", + ): + futures.append( + ex.submit( + find_recording_and_create_new_supervision, + enhanced_dir, + supervision, + ) + ) + + recordings = [] + supervisions = [] + for future in tqdm( + futures, + total=len(futures), + desc="Processing tasks", + ): + result = future.result() + if result is not None: + recording, new_supervision = result + recordings.append(recording) + supervisions.append(new_supervision) + + # Remove duplicates from the recordings + recordings_nodup = {} + for recording in recordings: + if recording.id not in recordings_nodup: + recordings_nodup[recording.id] = recording + else: + logging.warning("Recording {} is duplicated.".format(recording.id)) + recordings = RecordingSet.from_recordings(recordings_nodup.values()) + supervisions = SupervisionSet.from_segments(supervisions) + + recordings, supervisions = fix_manifests( + recordings=recordings, supervisions=supervisions + ) + + logging.info(f"Writing {part} enhanced manifests") + recordings.to_file( + manifests_dir / f"alimeeting-gss_recordings_{part}.jsonl.gz" + ) + supervisions.to_file( + manifests_dir / f"alimeeting-gss_supervisions_{part}.jsonl.gz" + ) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh new file mode 100755 index 000000000..76db19832 --- /dev/null +++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# This script is used to run GSS-based enhancement on AMI data. +set -euo pipefail +nj=4 +stage=0 + +. shared/parse_options.sh || exit 1 + +if [ $# != 2 ]; then + echo "Wrong #arguments ($#, expected 2)" + echo "Usage: local/prepare_alimeeting_gss.sh [options] " + echo "e.g. 
local/prepare_alimeeting_gss.sh data/manifests exp/ami_gss" + echo "main options (for others, see top of script file)" + echo " --nj # number of parallel jobs" + echo " --stage # stage to start running from" + exit 1; +fi + +DATA_DIR=$1 +EXP_DIR=$2 + +mkdir -p $EXP_DIR + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 1 ]; then + log "Stage 1: Prepare cut sets" + for part in train eval test; do + lhotse cut simple \ + -r $DATA_DIR/alimeeting-mdm_recordings_${part}.jsonl.gz \ + -s $DATA_DIR/alimeeting-mdm_supervisions_${part}.jsonl.gz \ + $EXP_DIR/cuts_${part}.jsonl.gz + done +fi + +if [ $stage -le 2 ]; then + log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)" + for part in train eval test; do + lhotse cut trim-to-supervisions --discard-overlapping \ + $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz + done +fi + +if [ $stage -le 3 ]; then + log "Stage 3: Split manifests for multi-GPU processing (optional)" + for part in train eval test; do + gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj + done +fi + +if [ $stage -le 4 ]; then + log "Stage 4: Enhance train segments using GSS (requires GPU)" + # for train, we use smaller context and larger batches to speed-up processing + for JOB in $(seq $nj); do + gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ + $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \ + --bss-iterations 10 \ + --context-duration 5.0 \ + --use-garbage-class \ + --channels 0,1,2,3,4,5,6,7 \ + --min-segment-length 0.05 \ + --max-segment-length 25.0 \ + --max-batch-duration 60.0 \ + --num-buckets 4 \ + --num-workers 4 + done +fi + +if [ $stage -le 5 ]; then + log "Stage 5: Enhance eval/test segments using GSS (using GPU)" + # for eval/test, we use larger context and smaller batches to get better quality + for part in eval test; do + for JOB in $(seq $nj); do + gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \ + $EXP_DIR/enhanced \ + --bss-iterations 10 \ + --context-duration 15.0 \ + --use-garbage-class \ + --channels 0,1,2,3,4,5,6,7 \ + --min-segment-length 0.05 \ + --max-segment-length 16.0 \ + --max-batch-duration 45.0 \ + --num-buckets 4 \ + --num-workers 4 + done + done +fi + +if [ $stage -le 6 ]; then + log "Stage 6: Prepare manifests for GSS-enhanced data" + python local/prepare_alimeeting_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05 +fi diff --git a/egs/alimeeting/ASR_v2/local/prepare_char.py b/egs/alimeeting/ASR_v2/local/prepare_char.py new file mode 120000 index 000000000..ee5dd34f1 --- /dev/null +++ b/egs/alimeeting/ASR_v2/local/prepare_char.py @@ -0,0 +1 @@ +../../ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/local/prepare_words.py b/egs/alimeeting/ASR_v2/local/prepare_words.py new file mode 120000 index 000000000..970bfd60c --- /dev/null +++ b/egs/alimeeting/ASR_v2/local/prepare_words.py @@ -0,0 +1 @@ +../../ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/local/text2segments.py b/egs/alimeeting/ASR_v2/local/text2segments.py new file mode 120000 index 000000000..bf4547794 --- /dev/null +++ b/egs/alimeeting/ASR_v2/local/text2segments.py @@ -0,0 +1 @@ +../../ASR/local/text2segments.py 
\ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/local/text2token.py b/egs/alimeeting/ASR_v2/local/text2token.py new file mode 120000 index 000000000..f6b8531b6 --- /dev/null +++ b/egs/alimeeting/ASR_v2/local/text2token.py @@ -0,0 +1 @@ +../../ASR/local/text2token.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh new file mode 100755 index 000000000..76a108771 --- /dev/null +++ b/egs/alimeeting/ASR_v2/prepare.sh @@ -0,0 +1,125 @@ +#!/usr/bin/env bash + +set -eou pipefail + +stage=-1 +stop_stage=100 +use_gss=true # Use GSS-based enhancement with MDM setting + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/alimeeting +# This directory contains the following files downloaded from +# https://openslr.org/62/ +# +# - Train_Ali_far.tar.gz +# - Train_Ali_near.tar.gz +# - Test_Ali.tar.gz +# - Eval_Ali.tar.gz +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + if [ ! -f $dl_dir/alimeeting/Train_Ali_far.tar.gz ]; then + lhotse download ali-meeting $dl_dir/alimeeting + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare alimeeting manifest" + # We assume that you have downloaded the alimeeting corpus + # to $dl_dir/alimeeting + for part in ihm sdm mdm; do + mkdir -p data/manifests/alimeeting + lhotse prepare ali-meeting --mic $part --save-mono --normalize-text m2met \ + $dl_dir/alimeeting data/manifests + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then + log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)" + # We assume that you have installed the GSS package: https://github.com/desh2608/gss + local/prepare_alimeeting_gss.sh data/manifests exp/alimeeting_gss +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + python local/compute_fbank_musan.py +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute fbank for alimeeting" + mkdir -p data/fbank + python local/compute_fbank_alimeeting.py + log "Combine features from train splits" + lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\ + gzip -c > data/manifests/cuts_train_all.jsonl.gz +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare char based lang" + lang_char_dir=data/lang_char + mkdir -p $lang_char_dir + + # Prepare text. 
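+  # The pipeline below extracts the "text" field from the SDM train supervisions,
+  # strips the surrounding JSON quotes, and splits each utterance into characters.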
+ # Note: in Linux, you can install jq with the following command: + # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + gunzip -c data/manifests/alimeeting-sdm_supervisions_train.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $lang_char_dir/text + + # Prepare words segments + python ./local/text2segments.py \ + --input $lang_char_dir/text \ + --output $lang_char_dir/text_words_segmentation + + cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \ + | sort -u | sed "/^$/d" \ + | uniq > $lang_char_dir/words_no_ids.txt + + # Prepare words.txt + if [ ! -f $lang_char_dir/words.txt ]; then + ./local/prepare_words.py \ + --input-file $lang_char_dir/words_no_ids.txt \ + --output-file $lang_char_dir/words.txt + fi + + if [ ! -f $lang_char_dir/L_disambig.pt ]; then + ./local/prepare_char.py + fi +fi diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py new file mode 100644 index 000000000..1cfd053c7 --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py @@ -0,0 +1,419 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.cut import Cut +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader +from tqdm import tqdm + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AlimeetingAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. 
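+    In this recipe the test dataloaders cover the IHM, SDM, and GSS-enhanced
+    eval/test sets returned by the ``*_cuts()`` methods at the bottom of this class.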
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), + ) + group.add_argument( + "--max-duration", + type=int, + default=100.0, + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), + ) + group.add_argument( + "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch." + ) + group.add_argument( + "--num-buckets", + type=int, + default=50, + help=( + "The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets)." + ), + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help=( + "When enabled (=default), the examples will be " + "shuffled for each epoch." + ), + ) + + group.add_argument( + "--num-workers", + type=int, + default=8, + help=( + "The number of training dataloader workers that " "collect the batches." + ), + ) + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to get Musan cuts") + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + "Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + if self.args.on_the_fly_feats: + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + ) + else: + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + ) + + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + max_cuts=self.args.max_cuts, + shuffle=False, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=True, + ) + sampler = DynamicBucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + def remove_short_cuts(self, cut: Cut) -> bool: + """ + See: https://github.com/k2-fsa/icefall/issues/500 + Basically, the zipformer model subsamples the input using the following formula: + num_out_frames = ((num_in_frames - 7)//2 + 1)//2 + For num_out_frames to be at least 1, num_in_frames must be at least 9. 
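+        With the default 10 ms frame shift this corresponds to roughly 0.09 s of
+        audio; e.g. a 1 s cut has about 100 frames and yields
+        ((100 - 7) // 2 + 1) // 2 = 23 output frames.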
+ """ + return cut.duration >= 0.09 + + @lru_cache() + def train_cuts(self, sp: Optional[Any] = None) -> CutSet: + logging.info("About to get AMI train cuts") + + def _remove_short_and_long_utt(c: Cut): + if c.duration < 0.1 or c.duration > 25.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = c.supervisions[0].text + return T >= len(tokens) + + cuts_train = load_manifest_lazy( + self.args.manifest_dir / "cuts_train_all.jsonl.gz" + ) + + return cuts_train.filter(_remove_short_and_long_utt) + + @lru_cache() + def eval_ihm_cuts(self) -> CutSet: + logging.info("About to get AliMeeting IHM eval cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_ihm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def eval_sdm_cuts(self) -> CutSet: + logging.info("About to get AliMeeting SDM eval cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_sdm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def eval_gss_cuts(self) -> CutSet: + if not (self.args.manifest_dir / "cuts_eval_gss.jsonl.gz").exists(): + logging.info("No GSS dev cuts found") + return None + logging.info("About to get AliMeeting GSS-enhanced eval cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_gss.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_ihm_cuts(self) -> CutSet: + logging.info("About to get AliMeeting IHM test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_sdm_cuts(self) -> CutSet: + logging.info("About to get AliMeeting SDM test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_gss_cuts(self) -> CutSet: + if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists(): + logging.info("No GSS test cuts found") + return None + logging.info("About to get AliMeeting GSS-enhanced test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz") + return cs.filter(self.remove_short_cuts) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py new file mode 120000 index 000000000..37516affc --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py new file mode 100755 index 000000000..2741e0eeb --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py @@ -0,0 +1,692 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 15 \ + --avg 8 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method greedy_search + +(2) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 15 \ + --avg 8 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 15 \ + --avg 8 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AlimeetingAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest_LG, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. 
+ model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: 
hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list(str(text).replace(" ", "")) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
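+        # Note that this recipe is character based (Mandarin), so the "WER"
+        # computed by write_error_stats is effectively a CER.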
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    AlimeetingAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest_LG",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            
model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + if "fast_beam_search" in params.decoding_method: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + alimeeting = AlimeetingAsrDataModule(args) + + eval_ihm_cuts = alimeeting.eval_ihm_cuts() + test_ihm_cuts = alimeeting.test_ihm_cuts() + eval_sdm_cuts = alimeeting.eval_sdm_cuts() + test_sdm_cuts = alimeeting.test_sdm_cuts() + eval_gss_cuts = alimeeting.eval_gss_cuts() + test_gss_cuts = alimeeting.test_gss_cuts() + + eval_ihm_dl = alimeeting.test_dataloaders(eval_ihm_cuts) + test_ihm_dl = alimeeting.test_dataloaders(test_ihm_cuts) + eval_sdm_dl = alimeeting.test_dataloaders(eval_sdm_cuts) + test_sdm_dl = alimeeting.test_dataloaders(test_sdm_cuts) + if eval_gss_cuts is not None: + eval_gss_dl = alimeeting.test_dataloaders(eval_gss_cuts) + if test_gss_cuts is not None: + test_gss_dl = alimeeting.test_dataloaders(test_gss_cuts) + + test_sets = { + "eval_ihm": (eval_ihm_dl, eval_ihm_cuts), + "test_ihm": (test_ihm_dl, test_ihm_cuts), + "eval_sdm": (eval_sdm_dl, eval_sdm_cuts), + "test_sdm": (test_sdm_dl, test_sdm_cuts), + } + if eval_gss_cuts is not None: + test_sets["eval_gss"] = (eval_gss_dl, eval_gss_cuts) + if test_gss_cuts is not None: + test_sets["test_gss"] = (test_gss_dl, test_gss_cuts) + + for test_set in test_sets: + logging.info(f"Decoding {test_set}") + dl, cuts = test_sets[test_set] + results_dict = decode_dataset( + dl=dl, + params=params, + model=model, + lexicon=lexicon, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py @@ -0,0 +1 @@ 
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py new file mode 120000 index 000000000..0c2673d46 --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py new file mode 100755 index 000000000..23a88dd29 --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7/decode.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. 
You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=15, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=8, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. 
Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py new file mode 120000 index 000000000..a44034e34 --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py new file mode 120000 index 000000000..068f0f57f --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py new file mode 120000 index 000000000..7ceac5d10 --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git 
a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py new file mode 100755 index 000000000..757d6535e --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -0,0 +1,1186 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 15 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 150 \ + --use-fp16 True + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AlimeetingAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the 
zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=15, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=5000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=10, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. 
+ + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + + y = graph_compiler.texts_to_ids(texts) + if type(y) == list: + y = k2.RaggedTensor(y).to(device) + else: + y = y.to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ((feature_lens - 7) // 2).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
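            # Concretely (reading the code below): the scale is doubled every 100 batches
            # while it is below 1.0, and every 400 batches while it is below 8.0; a warning
            # is printed once it drops below 0.01, and training aborts below 1e-5, which
            # likely indicates persistent inf/nan gradients.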
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + if params.inf_check: + register_inf_check_hooks(model) + + alimeeting = AlimeetingAsrDataModule(args) + + train_cuts = alimeeting.train_cuts() + train_dl = alimeeting.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = alimeeting.eval_ihm_cuts() + valid_dl = alimeeting.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # graph_compiler=graph_compiler, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + 
epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AlimeetingAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/alimeeting/ASR_v2/shared b/egs/alimeeting/ASR_v2/shared new file mode 120000 index 000000000..3a3b28f96 --- /dev/null +++ b/egs/alimeeting/ASR_v2/shared @@ -0,0 +1 @@ +../../../egs/aishell/ASR/shared \ No newline at end of file diff --git a/egs/ami/ASR/README.md b/egs/ami/ASR/README.md new file mode 100644 index 000000000..1c9714bd4 --- /dev/null +++ b/egs/ami/ASR/README.md @@ -0,0 +1,48 @@ +# AMI + +This is an ASR recipe for the AMI corpus. AMI provides recordings from the speaker's +headset and lapel microphones, and also 2 array microphones containing 8 channels each. +We pool data in the following 4 ways and train a single model on the pooled data: + +(i) individual headset microphone (IHM) +(ii) IHM with simulated reverb +(iii) Single distant microphone (SDM) +(iv) GSS-enhanced array microphones + +Speed perturbation and MUSAN noise augmentation are additionally performed on the pooled +data. Here are the statistics of the combined training data: + +```python +>>> cuts_train.describe() +Cuts count: 1222053 +Total duration (hh:mm:ss): 905:00:28 +Speech duration (hh:mm:ss): 905:00:28 (99.9%) +Duration statistics (seconds): +mean 2.7 +std 2.8 +min 0.0 +25% 0.6 +50% 1.6 +75% 3.8 +99% 12.3 +99.5% 13.9 +99.9% 18.4 +max 36.8 +``` + +**Note:** This recipe additionally uses [GSS](https://github.com/desh2608/gss) for enhancement +of far-field array microphones, but this is optional (see `prepare.sh` for details). + +## Performance Record + +### pruned_transducer_stateless7 + +The following are decoded using `modified_beam_search`: + +| Evaluation set | dev WER | test WER | +|--------------------------|------------|---------| +| IHM | 18.92 | 17.40 | +| SDM | 31.25 | 32.21 | +| MDM (GSS-enhanced) | 21.67 | 22.43 | + +See [RESULTS](/egs/ami/ASR/RESULTS.md) for details. diff --git a/egs/ami/ASR/RESULTS.md b/egs/ami/ASR/RESULTS.md new file mode 100644 index 000000000..163986021 --- /dev/null +++ b/egs/ami/ASR/RESULTS.md @@ -0,0 +1,92 @@ +## Results + +### AMI training results (Pruned Transducer) + +#### 2022-11-20 + +#### Zipformer (pruned_transducer_stateless7) + +Zipformer encoder + non-current decoder. The decoder +contains only an embedding layer, a Conv1d (with kernel size 2) and a linear +layer (to transform tensor dim). + +All the results below are using a single model that is trained by combining the following +data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise +augmentation are applied on top of the pooled data. 
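The decoder described above keeps no recurrent state: it only looks at a fixed number of previous labels (`--context-size`, 2 by default in this recipe). The snippet below is a minimal, hypothetical sketch of that idea for illustration only; it is not the recipe's `decoder.py`, and the class and parameter names are made up here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoderSketch(nn.Module):
    """Illustrative stand-in for a stateless transducer decoder:
    embedding -> Conv1d over the last `context_size` labels -> linear."""

    def __init__(self, vocab_size: int, decoder_dim: int, blank_id: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_dim, padding_idx=blank_id)
        self.context_size = context_size
        # groups=decoder_dim gives a per-channel convolution over the label axis,
        # so each output position sees only the last `context_size` labels
        # (a finite label context instead of an RNN state).
        self.conv = nn.Conv1d(
            decoder_dim, decoder_dim, kernel_size=context_size,
            groups=decoder_dim, bias=False,
        )
        self.output_linear = nn.Linear(decoder_dim, decoder_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) previous labels; returns (N, U, decoder_dim).
        emb = self.embedding(y).permute(0, 2, 1)          # (N, D, U)
        emb = F.pad(emb, pad=(self.context_size - 1, 0))  # left-pad so length stays U
        out = self.conv(emb).permute(0, 2, 1)             # (N, U, D)
        return self.output_linear(F.relu(out))


if __name__ == "__main__":
    dec = StatelessDecoderSketch(vocab_size=500, decoder_dim=512, blank_id=0, context_size=2)
    print(dec(torch.randint(1, 500, (4, 10))).shape)  # torch.Size([4, 10, 512])
```

Because each output position depends only on the previous `context_size` labels, the prediction network behaves like an n-gram context model ("1 means bigram; 2 means tri-gram" in the `--context-size` help above) rather than an LSTM.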
+ +**WERs for IHM:** + +| | dev | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 19.25 | 17.83 | --epoch 14 --avg 8 --max-duration 500 | +| modified beam search | 18.92 | 17.40 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 19.44 | 18.04 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +**WERs for SDM:** + +| | dev | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 31.32 | 32.38 | --epoch 14 --avg 8 --max-duration 500 | +| modified beam search | 31.25 | 32.21 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 31.11 | 32.10 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +**WERs for GSS-enhanced MDM:** + +| | dev | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 22.05 | 22.93 | --epoch 14 --avg 8 --max-duration 500 | +| modified beam search | 21.67 | 22.43 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 22.21 | 22.83 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 15 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 150 \ + --max-cuts 150 \ + --prune-range 5 \ + --lr-factor 5 \ + --lm-scale 0.25 \ + --use-fp16 True +``` + +The decoding command is: +``` +# greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 14 \ + --avg 8 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method greedy_search + +# modified beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +# fast beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +``` + +Pretrained model is available at + +The tensorboard training log can be found at + diff --git a/egs/ami/ASR/local/__init__.py b/egs/ami/ASR/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ami/ASR/local/compute_fbank_ami.py b/egs/ami/ASR/local/compute_fbank_ami.py new file mode 100755 index 000000000..4892b40e3 --- /dev/null +++ b/egs/ami/ASR/local/compute_fbank_ami.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the AMI dataset. +For the training data, we pool together IHM, reverberated IHM, and GSS-enhanced +audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced +parts (which are the 3 evaluation settings). +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" +import logging +import math +from pathlib import Path + +import torch +import torch.multiprocessing +from lhotse import CutSet, LilcomChunkyWriter +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def compute_fbank_ami(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + logging.info("Reading manifests") + manifests_ihm = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=src_dir, + prefix="ami-ihm", + suffix="jsonl.gz", + ) + manifests_sdm = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=src_dir, + prefix="ami-sdm", + suffix="jsonl.gz", + ) + # For GSS we already have cuts so we read them directly. 
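    # (The ami-gss recordings/supervisions are created beforehand by
    # local/prepare_ami_gss.sh and local/prepare_ami_enhanced.py; see
    # prepare.sh stage 3.)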
+ manifests_gss = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=src_dir, + prefix="ami-gss", + suffix="jsonl.gz", + ) + + def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None: + cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) + _ = cuts.compute_and_store_features_batch( + extractor=extractor, + storage_path=storage_path, + manifest_path=manifest_path, + batch_duration=5000, + num_workers=8, + storage_type=LilcomChunkyWriter, + ) + + logging.info( + "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)" + ) + + logging.info("Processing train split IHM") + cuts_ihm = ( + CutSet.from_manifests(**manifests_ihm["train"]) + .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) + .modify_ids(lambda x: x + "-ihm") + ) + _extract_feats( + cuts_ihm, + output_dir / "feats_train_ihm", + src_dir / "cuts_train_ihm.jsonl.gz", + ) + + logging.info("Processing train split IHM + reverberated IHM") + cuts_ihm_rvb = cuts_ihm.reverb_rir() + _extract_feats( + cuts_ihm_rvb, + output_dir / "feats_train_ihm_rvb", + src_dir / "cuts_train_ihm_rvb.jsonl.gz", + ) + + logging.info("Processing train split SDM") + cuts_sdm = ( + CutSet.from_manifests(**manifests_sdm["train"]) + .trim_to_supervisions(keep_overlapping=False) + .modify_ids(lambda x: x + "-sdm") + ) + _extract_feats( + cuts_sdm, + output_dir / "feats_train_sdm", + src_dir / "cuts_train_sdm.jsonl.gz", + ) + + logging.info("Processing train split GSS") + cuts_gss = ( + CutSet.from_manifests(**manifests_gss["train"]) + .trim_to_supervisions(keep_overlapping=False) + .modify_ids(lambda x: x + "-gss") + ) + _extract_feats( + cuts_gss, + output_dir / "feats_train_gss", + src_dir / "cuts_train_gss.jsonl.gz", + ) + + logging.info("Preparing test cuts: IHM, SDM, GSS (optional)") + for split in ["dev", "test"]: + logging.info(f"Processing {split} IHM") + cuts_ihm = ( + CutSet.from_manifests(**manifests_ihm[split]) + .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_ihm", + manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz", + batch_duration=5000, + num_workers=8, + storage_type=LilcomChunkyWriter, + ) + ) + logging.info(f"Processing {split} SDM") + cuts_sdm = ( + CutSet.from_manifests(**manifests_sdm[split]) + .trim_to_supervisions(keep_overlapping=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_sdm", + manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz", + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + logging.info(f"Processing {split} GSS") + cuts_gss = ( + CutSet.from_manifests(**manifests_gss[split]) + .trim_to_supervisions(keep_overlapping=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_gss", + manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz", + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_ami() diff --git a/egs/ami/ASR/local/compute_fbank_musan.py b/egs/ami/ASR/local/compute_fbank_musan.py new file mode 100755 index 000000000..1fcf951f9 --- /dev/null +++ b/egs/ami/ASR/local/compute_fbank_musan.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# 
Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the musan dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import logging +from pathlib import Path + +import torch +from lhotse import CutSet, LilcomChunkyWriter, combine +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_musan(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + dataset_parts = ( + "music", + "speech", + "noise", + ) + prefix = "musan" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + musan_cuts_path = src_dir / "musan_cuts.jsonl.gz" + + if musan_cuts_path.is_file(): + logging.info(f"{musan_cuts_path} already exists - skipping") + return + + logging.info("Extracting features for Musan") + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + # create chunks of Musan with duration 5 - 10 seconds + _ = ( + CutSet.from_manifests( + recordings=combine(part["recordings"] for part in manifests.values()) + ) + .cut_into_windows(10.0) + .filter(lambda c: c.duration > 5) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / "musan_feats", + manifest_path=musan_cuts_path, + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_musan() diff --git a/egs/ami/ASR/local/prepare_ami_enhanced.py b/egs/ami/ASR/local/prepare_ami_enhanced.py new file mode 100644 index 000000000..bed220eb3 --- /dev/null +++ b/egs/ami/ASR/local/prepare_ami_enhanced.py @@ -0,0 +1,158 @@ +#!/usr/local/bin/python +# -*- coding: utf-8 -*- +# Data preparation for AMI GSS-enhanced dataset. 
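#
# Given the original AMI SDM manifests and a directory of GSS-enhanced segments
# (one FLAC file per supervision, named
# <recording_id>-<speaker>-<start>_<end>.flac with start/end in centiseconds),
# this script writes ami-gss recording and supervision manifests whose
# timestamps are relative to the enhanced segments.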
+ +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +from lhotse import Recording, RecordingSet, SupervisionSet +from lhotse.qa import fix_manifests +from lhotse.recipes.utils import read_manifests_if_cached +from lhotse.utils import fastcopy +from tqdm import tqdm + +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.") + parser.add_argument( + "manifests_dir", + type=Path, + help="Path to directory containing AMI manifests.", + ) + parser.add_argument( + "enhanced_dir", + type=Path, + help="Path to enhanced data directory.", + ) + parser.add_argument( + "--num-jobs", + "-j", + type=int, + default=1, + help="Number of parallel jobs to run.", + ) + parser.add_argument( + "--min-segment-duration", + "-d", + type=float, + default=0.0, + help="Minimum duration of a segment in seconds.", + ) + return parser.parse_args() + + +def find_recording_and_create_new_supervision(enhanced_dir, supervision): + """ + Given a supervision (corresponding to original AMI recording), this function finds the + enhanced recording correspoding to the supervision, and returns this recording and + a new supervision whose start and end times are adjusted to match the enhanced recording. + """ + file_name = Path( + f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac" + ) + save_path = enhanced_dir / f"{supervision.recording_id}" / file_name + if save_path.exists(): + recording = Recording.from_file(save_path) + if recording.duration == 0: + logging.warning(f"Skipping {save_path} which has duration 0 seconds.") + return None + + # Old supervision is wrt to the original recording, we create new supervision + # wrt to the enhanced segment + new_supervision = fastcopy( + supervision, + recording_id=recording.id, + start=0, + duration=recording.duration, + ) + return recording, new_supervision + else: + logging.warning(f"{save_path} does not exist.") + return None + + +def main(args): + # Get arguments + manifests_dir = args.manifests_dir + enhanced_dir = args.enhanced_dir + + # Load manifests from cache if they exist (saves time) + manifests = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=manifests_dir, + prefix="ami-sdm", + suffix="jsonl.gz", + ) + if not manifests: + raise ValueError("AMI SDM manifests not found in {}".format(manifests_dir)) + + with ThreadPoolExecutor(args.num_jobs) as ex: + for part in ["train", "dev", "test"]: + logging.info(f"Processing {part}...") + supervisions_orig = manifests[part]["supervisions"].filter( + lambda s: s.duration >= args.min_segment_duration + ) + # Remove TS3009d supervisions since they are not present in the enhanced data + supervisions_orig = supervisions_orig.filter( + lambda s: s.recording_id != "TS3009d" + ) + futures = [] + + for supervision in tqdm( + supervisions_orig, + desc="Distributing tasks", + ): + futures.append( + ex.submit( + find_recording_and_create_new_supervision, + enhanced_dir, + supervision, + ) + ) + + recordings = [] + supervisions = [] + for future in tqdm( + futures, + total=len(futures), + desc="Processing tasks", + ): + result = future.result() + if result is not None: + recording, new_supervision = result + recordings.append(recording) + supervisions.append(new_supervision) + + # Remove duplicates from the 
recordings + recordings_nodup = {} + for recording in recordings: + if recording.id not in recordings_nodup: + recordings_nodup[recording.id] = recording + else: + logging.warning("Recording {} is duplicated.".format(recording.id)) + recordings = RecordingSet.from_recordings(recordings_nodup.values()) + supervisions = SupervisionSet.from_segments(supervisions) + + recordings, supervisions = fix_manifests( + recordings=recordings, supervisions=supervisions + ) + + logging.info(f"Writing {part} enhanced manifests") + recordings.to_file(manifests_dir / f"ami-gss_recordings_{part}.jsonl.gz") + supervisions.to_file( + manifests_dir / f"ami-gss_supervisions_{part}.jsonl.gz" + ) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/ami/ASR/local/prepare_ami_gss.sh b/egs/ami/ASR/local/prepare_ami_gss.sh new file mode 100755 index 000000000..d5422458b --- /dev/null +++ b/egs/ami/ASR/local/prepare_ami_gss.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# This script is used to run GSS-based enhancement on AMI data. +set -euo pipefail +nj=4 +stage=0 + +. shared/parse_options.sh || exit 1 + +if [ $# != 2 ]; then + echo "Wrong #arguments ($#, expected 2)" + echo "Usage: local/prepare_ami_gss.sh [options] " + echo "e.g. local/prepare_ami_gss.sh data/manifests exp/ami_gss" + echo "main options (for others, see top of script file)" + echo " --nj # number of parallel jobs" + echo " --stage # stage to start running from" + exit 1; +fi + +DATA_DIR=$1 +EXP_DIR=$2 + +mkdir -p $EXP_DIR + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 1 ]; then + log "Stage 1: Prepare cut sets" + for part in train dev test; do + lhotse cut simple \ + -r $DATA_DIR/ami-mdm_recordings_${part}.jsonl.gz \ + -s $DATA_DIR/ami-mdm_supervisions_${part}.jsonl.gz \ + $EXP_DIR/cuts_${part}.jsonl.gz + done +fi + +if [ $stage -le 2 ]; then + log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)" + for part in train dev test; do + lhotse cut trim-to-supervisions --discard-overlapping \ + $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz + done +fi + +if [ $stage -le 3 ]; then + log "Stage 3: Split manifests for multi-GPU processing (optional)" + for part in train; do + gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj + done +fi + +if [ $stage -le 4 ]; then + log "Stage 4: Enhance train segments using GSS (requires GPU)" + # for train, we use smaller context and larger batches to speed-up processing + for JOB in $(seq $nj); do + gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ + $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \ + --bss-iterations 10 \ + --context-duration 5.0 \ + --use-garbage-class \ + --channels 0,1,2,3,4,5,6,7 \ + --min-segment-length 0.05 \ + --max-segment-length 35.0 \ + --max-batch-duration 60.0 \ + --num-buckets 3 \ + --num-workers 2 + done +fi + +if [ $stage -le 5 ]; then + log "Stage 5: Enhance dev/test segments using GSS (using GPU)" + # for dev/test, we use larger context and smaller batches to get better quality + for part in dev test; do + for JOB in $(seq $nj); do + gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \ + $EXP_DIR/enhanced \ + --bss-iterations 10 \ + --context-duration 15.0 \ + --use-garbage-class \ + --channels 
0,1,2,3,4,5,6,7 \ + --min-segment-length 0.05 \ + --max-segment-length 30.0 \ + --max-batch-duration 45.0 \ + --num-buckets 3 \ + --num-workers 2 + done + done +fi + +if [ $stage -le 6 ]; then + log "Stage 6: Prepare manifests for GSS-enhanced data" + python local/prepare_ami_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05 +fi diff --git a/egs/ami/ASR/local/prepare_lang_bpe.py b/egs/ami/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/ami/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/ami/ASR/local/train_bpe_model.py b/egs/ami/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/ami/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ami/ASR/prepare.sh b/egs/ami/ASR/prepare.sh new file mode 100755 index 000000000..fb21a8ec6 --- /dev/null +++ b/egs/ami/ASR/prepare.sh @@ -0,0 +1,144 @@ +#!/usr/bin/env bash + +set -eou pipefail + +stage=-1 +stop_stage=100 +use_gss=true # Use GSS-based enhancement with MDM setting + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/amicorpus +# You can find audio and transcripts in this path. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +# +# - $dl_dir/{LDC2004S13,LDC2005S13,LDC2004T19,LDC2005T19} +# These contain the Fisher English audio and transcripts. We will +# only use the transcripts as extra LM training data (similar to Kaldi). +# +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +vocab_size=500 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/amicorpus, + # you can create a symlink + # + # ln -sfv /path/to/amicorpus $dl_dir/amicorpus + # + if [ ! -d $dl_dir/amicorpus ]; then + lhotse download ami --mic ihm $dl_dir/amicorpus + lhotse download ami --mic mdm $dl_dir/amicorpus + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare AMI manifests" + # We assume that you have downloaded the AMI corpus + # to $dl_dir/amicorpus. We perform text normalization for the transcripts. 
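+  # (A rough sketch of what this stage writes; the manifest names below are
+  # assumed from the ami-<mic>_<kind>_<part>.jsonl.gz pattern used elsewhere
+  # in this recipe, so verify them against your lhotse version.)
+  #
+  #   data/manifests/ami-sdm_recordings_train.jsonl.gz
+  #   data/manifests/ami-sdm_supervisions_train.jsonl.gz
+  #   ... and likewise for ihm/mdm and dev/test.
+  #
+  # A quick sanity check once the stage finishes:
+  #   gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | head -n 1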
+ mkdir -p data/manifests + for mic in ihm sdm mdm; do + lhotse prepare ami --mic $mic --partition full-corpus-asr --normalize-text kaldi \ + --max-words-per-segment 30 $dl_dir/amicorpus data/manifests/ + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then + log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)" + # We assume that you have installed the GSS package: https://github.com/desh2608/gss + local/prepare_ami_gss.sh data/manifests exp/ami_gss +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank features for AMI" + mkdir -p data/fbank + python local/compute_fbank_ami.py + log "Combine features from train splits" + lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\ + gzip -c > data/manifests/cuts_train_all.jsonl.gz +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute fbank features for musan" + mkdir -p data/fbank + python local/compute_fbank_musan.py +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Dump transcripts for BPE model training." + mkdir -p data/lm + cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g')> data/lm/transcript_words.txt +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare BPE based lang" + + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + # Add special words to words.txt + echo " 0" > $lang_dir/words.txt + echo "!SIL 1" >> $lang_dir/words.txt + echo " 2" >> $lang_dir/words.txt + + # Add regular words to words.txt + cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt + + # Add remaining special word symbols expected by LM scripts. + num_words=$(cat $lang_dir/words.txt | wc -l) + echo " ${num_words}" >> $lang_dir/words.txt + num_words=$(cat $lang_dir/words.txt | wc -l) + echo " ${num_words}" >> $lang_dir/words.txt + num_words=$(cat $lang_dir/words.txt | wc -l) + echo "#0 ${num_words}" >> $lang_dir/words.txt + + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript data/lm/transcript_words.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + fi +fi diff --git a/egs/ami/ASR/pruned_transducer_stateless7/__init__.py b/egs/ami/ASR/pruned_transducer_stateless7/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py new file mode 100644 index 000000000..f7ee9c962 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -0,0 +1,430 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.cut import Cut +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader +from tqdm import tqdm + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AmiAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), + ) + group.add_argument( + "--max-duration", + type=int, + default=100.0, + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), + ) + group.add_argument( + "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch." 
+ ) + group.add_argument( + "--num-buckets", + type=int, + default=50, + help=( + "The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets)." + ), + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help=( + "When enabled (=default), the examples will be " + "shuffled for each epoch." + ), + ) + + group.add_argument( + "--num-workers", + type=int, + default=8, + help=( + "The number of training dataloader workers that " "collect the batches." + ), + ) + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), + ) + group.add_argument( + "--ihm-only", + type=str2bool, + default=False, + help="When enabled, only use IHM data for training.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to get Musan cuts") + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + "Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. 
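+            # For example, with MUSAN enabled the list becomes
+            # [CutConcatenate(...), CutMix(...)]: cuts are concatenated first
+            # and the noise is then mixed into the concatenated result, so it
+            # also covers the padding gaps inserted between utterances.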
+ transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + if self.args.on_the_fly_feats: + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + ) + else: + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + ) + + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + max_cuts=self.args.max_cuts, + shuffle=False, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=True, + ) + sampler = DynamicBucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + def remove_short_cuts(self, cut: Cut) -> bool: + """ + See: https://github.com/k2-fsa/icefall/issues/500 + Basically, the zipformer model subsamples the input using the following formula: + num_out_frames = (num_in_frames - 7)//2 + For num_out_frames to be at least 1, num_in_frames must be at least 9. 
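+        With the default 10 ms fbank frame shift used in this recipe, 9 input
+        frames correspond to 0.09 seconds, which is the threshold used below.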
+ """ + return cut.duration >= 0.09 + + @lru_cache() + def train_cuts(self, sp: Optional[Any] = None) -> CutSet: + logging.info("About to get AMI train cuts") + + def _remove_short_and_long_utt(c: Cut): + if c.duration < 0.2 or c.duration > 25.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + return T >= len(tokens) + + if self.args.ihm_only: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / "cuts_train_ihm.jsonl.gz" + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / "cuts_train_all.jsonl.gz" + ) + + return cuts_train.filter(_remove_short_and_long_utt) + + @lru_cache() + def dev_ihm_cuts(self) -> CutSet: + logging.info("About to get AMI IHM dev cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def dev_sdm_cuts(self) -> CutSet: + logging.info("About to get AMI SDM dev cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def dev_gss_cuts(self) -> CutSet: + if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists(): + logging.info("No GSS dev cuts found") + return None + logging.info("About to get AMI GSS-enhanced dev cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_gss.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_ihm_cuts(self) -> CutSet: + logging.info("About to get AMI IHM test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_sdm_cuts(self) -> CutSet: + logging.info("About to get AMI SDM test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_gss_cuts(self) -> CutSet: + if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists(): + logging.info("No GSS test cuts found") + return None + logging.info("About to get AMI GSS-enhanced test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz") + return cs.filter(self.remove_short_cuts) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py new file mode 120000 index 000000000..37516affc --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py new file mode 100755 index 000000000..9999894d1 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -0,0 +1,739 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AmiAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest_LG, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, + word_table: Optional[k2.SymbolTable] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. 
+ Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + word_table: + The word symbol table. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": 
hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, + word_table: Optional[k2.SymbolTable] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" + with open(wers_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + # we also compute CER for AMI dataset. 
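+        # Character-level pairs are built by joining each ref/hyp word list
+        # without spaces and splitting into characters, e.g.
+        # ["hello", "world"] -> list("helloworld"), so word boundaries do not
+        # contribute to the CER.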
+ results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" + with open(cers_filename, "w") as f: + cer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed error stats to {}".format(wers_filename)) + + test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} + test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER\tCER", file=f) + for key in test_set_wers: + print( + "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]), + file=f, + ) + + s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key in test_set_wers: + s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AmiAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest_LG", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(f"{params.lang_dir}/bpe.model") + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + 
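+            # average_checkpoints() returns a single state_dict whose tensors
+            # are (roughly) the element-wise mean of the corresponding tensors
+            # in the listed checkpoints; see icefall.checkpoint for the exact
+            # implementation.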
model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + ami = AmiAsrDataModule(args) + + dev_ihm_cuts = ami.dev_ihm_cuts() + test_ihm_cuts = ami.test_ihm_cuts() + dev_sdm_cuts = ami.dev_sdm_cuts() + test_sdm_cuts = ami.test_sdm_cuts() + dev_gss_cuts = ami.dev_gss_cuts() + test_gss_cuts = ami.test_gss_cuts() + + dev_ihm_dl = ami.test_dataloaders(dev_ihm_cuts) + test_ihm_dl = ami.test_dataloaders(test_ihm_cuts) + dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts) + test_sdm_dl = ami.test_dataloaders(test_sdm_cuts) + if dev_gss_cuts is not None: + dev_gss_dl = ami.test_dataloaders(dev_gss_cuts) + if test_gss_cuts is not None: + test_gss_dl = ami.test_dataloaders(test_gss_cuts) + + test_sets = { + "dev_ihm": (dev_ihm_dl, dev_ihm_cuts), + "test_ihm": (test_ihm_dl, test_ihm_cuts), + "dev_sdm": (dev_sdm_dl, dev_sdm_cuts), + "test_sdm": (test_sdm_dl, test_sdm_cuts), + } + if dev_gss_cuts is not None: + test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts) + if test_gss_cuts is not None: + test_sets["test_gss"] = (test_gss_dl, test_gss_cuts) + + for test_set in test_sets: + 
logging.info(f"Decoding {test_set}") + dl, cuts = test_sets[test_set] + results_dict = decode_dataset( + dl=dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decoder.py b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py new file mode 120000 index 000000000..0c2673d46 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/export.py b/egs/ami/ASR/pruned_transducer_stateless7/export.py new file mode 120000 index 000000000..2713792e6 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/joiner.py b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/model.py b/egs/ami/ASR/pruned_transducer_stateless7/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/optim.py b/egs/ami/ASR/pruned_transducer_stateless7/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py new file mode 100755 index 000000000..81823ced2 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -0,0 +1,1193 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 15 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 150 \ + --use-fp16 True + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AmiAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + 
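+    # These comma-separated strings are split into per-stack tuples in
+    # get_encoder_model() below, e.g. "384,384,384,384,384" becomes
+    # (384, 384, 384, 384, 384), one value per zipformer encoder stack.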
parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=11, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=5000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=10, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. 
+ + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
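+      sp:
+        The sentencepiece BPE model; it is used below to encode the
+        transcript texts into token ids.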
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"] + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = supervisions["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ((feature_lens - 7) // 2).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
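The scaling of the simple and pruned losses in `compute_loss` above follows a linear warm-up schedule. The sketch below reproduces that schedule in isolation; `warm_step=2000` matches the `get_params` default, while `simple_loss_scale=0.5` is an illustrative value (the real value comes from command-line options not shown here).

```python
# Sketch of the loss-scale schedule used in compute_loss above.
def loss_scales(batch_idx_train: int, warm_step: int, s: float):
    """Return (simple_loss_scale, pruned_loss_scale) at a given step."""
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    simple_scale = 1.0 - frac * (1.0 - s)  # decays from 1.0 to s
    pruned_scale = 0.1 + 0.9 * frac        # grows from 0.1 to 1.0
    return simple_scale, pruned_scale


if __name__ == "__main__":
    warm_step, s = 2000, 0.5  # s is illustrative, not the recipe default
    for step in (0, 500, 1000, 2000, 10000):
        print(step, loss_scales(step, warm_step, s))
    # step 0    -> (1.0, 0.1)
    # step 2000+ -> (0.5, 1.0)
```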
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
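The `tot_loss` update above is an exponentially decayed running sum, so dividing summed loss by summed frames yields a smoothed recent average rather than an all-epoch mean. A scalar sketch of that behaviour, using the `reset_interval=200` default from `get_params`:

```python
# Scalar sketch of: tot_loss = tot_loss * (1 - 1/reset_interval) + loss_info
reset_interval = 200
decay = 1.0 - 1.0 / reset_interval

tot = 0.0
for step in range(1, 1001):
    per_batch_loss = 1.0  # pretend every batch contributes 1.0
    tot = tot * decay + per_batch_loss
    if step in (1, 10, 200, 1000):
        print(step, round(tot, 2))
# The running sum saturates near reset_interval (~200 here), i.e. the
# statistics effectively cover roughly the last reset_interval batches.
```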
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
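The grad-scale watchdog above can be read in isolation as the following pure-Python sketch, with a plain float standing in for `torch.cuda.amp.GradScaler` (so this illustrates the heuristic, not the actual scaler API):

```python
# Pure-Python sketch of the grad-scale watchdog in train_one_epoch above.
def adjust_grad_scale(cur_grad_scale: float, batch_idx: int) -> float:
    if cur_grad_scale < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
    if cur_grad_scale < 0.01:
        print(f"warning: grad scale is small: {cur_grad_scale}")
    # Nudge the scale back up: always when it dropped below 1.0, and more
    # conservatively (every 400 batches) while it is still below 8.0.
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        return cur_grad_scale * 2.0
    return cur_grad_scale


if __name__ == "__main__":
    scale, batch_idx = 0.25, 400
    for _ in range(6):
        scale = adjust_grad_scale(scale, batch_idx)
        print(scale)
    # 0.5, 1.0, 2.0, 4.0, 8.0, 8.0 (stops doubling once it reaches 8.0)
```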
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + if params.inf_check: + register_inf_check_hooks(model) + + ami = AmiAsrDataModule(args) + + # Here is the duration statistics of the training set. 
+ # Cuts count: 1230033 + # Total duration (hh:mm:ss): 904:25:34 + # Speech duration (hh:mm:ss): 904:25:34 (100.0%) + # Duration statistics (seconds): + # mean 2.6 + # std 2.8 + # min 0.0 + # 25% 0.6 + # 50% 1.6 + # 75% 3.8 + # 99% 12.3 + # 99.5% 13.9 + # 99.9% 18.3 + # max 36.8 + + train_cuts = ami.train_cuts(sp=sp) + train_dl = ami.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) + + valid_cuts = ami.dev_ihm_cuts() + valid_dl = ami.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. 
We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AmiAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/ami/ASR/shared b/egs/ami/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/ami/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/commonvoice/ASR/README.md b/egs/commonvoice/ASR/README.md new file mode 100644 index 000000000..a4582499b --- /dev/null +++ b/egs/commonvoice/ASR/README.md @@ -0,0 +1,18 @@ +# Introduction + +This recipe includes some different ASR models trained with Common Voice + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +There are various folders containing the name `transducer` in this folder. +The following table lists the differences among them. + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|---------------------------------------------------| +| `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan | + +The decoder in `transducer_stateless` is modified from the paper +[RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md new file mode 100644 index 000000000..751625371 --- /dev/null +++ b/egs/commonvoice/ASR/RESULTS.md @@ -0,0 +1,59 @@ +## Results +### GigaSpeech BPE training results (Pruned Stateless Transducer 7) + +#### [pruned_transducer_stateless7](./pruned_transducer_stateless7) + +See #997 for more details. 
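The README above describes the stateless prediction network as an input embedding followed by a Conv1d over a small left context. The sketch below illustrates that idea; it is not the recipe's `Decoder` class, and the dimensions, `context_size`, and `groups` choice are assumptions for illustration.

```python
# Minimal sketch of a stateless prediction network (embedding + Conv1d),
# as described in the README above. Illustrative only.
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    def __init__(self, vocab_size: int, decoder_dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_dim)
        # A Conv1d over the last `context_size` emitted tokens replaces the
        # recurrent state of a conventional RNN-T prediction network.
        self.conv = nn.Conv1d(
            decoder_dim, decoder_dim, kernel_size=context_size, groups=decoder_dim
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token ids -> (N, U - context_size + 1, decoder_dim)
        emb = self.embedding(y).permute(0, 2, 1)  # (N, decoder_dim, U)
        out = self.conv(emb).permute(0, 2, 1)
        return torch.relu(out)


if __name__ == "__main__":
    dec = StatelessDecoder(vocab_size=500, decoder_dim=512, context_size=2)
    tokens = torch.randint(0, 500, (4, 10))
    print(dec(tokens).shape)  # torch.Size([4, 9, 512])
```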
+ +Number of model parameters: 70369391, i.e., 70.37 M + +The best WER, as of 2023-04-17, for Common Voice English 13.0 (cv-corpus-13.0-2023-03-09/en) is below: + +Results are: + +| | Dev | Test | +|----------------------|-------|-------| +| greedy search | 9.96 | 12.54 | +| modified beam search | 9.86 | 12.48 | + +To reproduce the above result, use the following commands for training: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 550 +``` + +and the following commands for decoding: + +```bash +# greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 5 \ + --decoding-method greedy_search \ + --exp-dir pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --max-duration 600 + +# modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 5 \ + --decoding-method modified_beam_search \ + --beam-size 4 \ + --exp-dir pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --max-duration 600 +``` + +Pretrained model is available at + + +The tensorboard log for training is available at + diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py new file mode 100755 index 000000000..c8f9b6ccb --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the CommonVoice dataset. +It looks for manifests in the directory data/${lang}/manifests. + +The generated fbank features are saved in data/${lang}/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + return parser.parse_args() + + +def compute_fbank_commonvoice_dev_test(language: str): + src_dir = Path(f"data/{language}/manifests") + output_dir = Path(f"data/{language}/fbank") + num_workers = 42 + batch_duration = 600 + + subsets = ("dev", "test") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + + logging.info(f"device: {device}") + + for partition in subsets: + cuts_path = output_dir / f"cv-{language}_cuts_{partition}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + + raw_cuts_path = output_dir / f"cv-{language}_cuts_{partition}_raw.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/cv-{language}_feats_{partition}", + num_workers=num_workers, + batch_duration=batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_commonvoice_dev_test(language=args.language) diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py new file mode 100755 index 000000000..0564f6ec6 --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from datetime import datetime +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + set_audio_duration_mismatch_tolerance, + set_caching_enabled, +) + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
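For reference, the feature-extraction pattern used by the dev/test script above condenses to the following sketch; the manifest paths and worker counts are hypothetical, but the Lhotse calls are the same ones the script uses.

```python
# Condensed sketch of the fbank extraction flow above (hypothetical paths).
import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

cut_set = CutSet.from_file("data/en/fbank/cv-en_cuts_dev_raw.jsonl.gz")
# One cut per supervision segment, so features line up with transcripts.
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False, min_duration=None)
cut_set = cut_set.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="data/en/fbank/cv-en_feats_dev",
    num_workers=4,
    batch_duration=600,  # seconds of audio per extraction batch
    storage_type=LilcomChunkyWriter,
    overwrite=True,
)
cut_set.to_file("data/en/fbank/cv-en_cuts_dev.jsonl.gz")
```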
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--num-splits", + type=int, + required=True, + help="The number of splits of the train subset", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="Process pieces starting from this number (inclusive).", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="Stop processing pieces until this number (exclusive).", + ) + + return parser.parse_args() + + +def compute_fbank_commonvoice_splits(args): + subset = "train" + num_splits = args.num_splits + language = args.language + output_dir = f"data/{language}/fbank/cv-{language}_{subset}_split_{num_splits}" + output_dir = Path(output_dir) + assert output_dir.exists(), f"{output_dir} does not exist!" + + num_digits = len(str(num_splits)) + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance + set_caching_enabled(False) + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = output_dir / f"cv-{language}_cuts_{subset}_raw.{idx}.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/cv-{language}_feats_{subset}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_commonvoice_splits(args) + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/local/compute_fbank_musan.py b/egs/commonvoice/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/filter_cuts.py b/egs/commonvoice/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/commonvoice/ASR/local/filter_cuts.py @@ -0,0 +1 @@ 
+../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/prepare_lang_bpe.py b/egs/commonvoice/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/commonvoice/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py new file mode 100755 index 000000000..c5ec14502 --- /dev/null +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import re +from pathlib import Path +from typing import Optional + +from lhotse import CutSet, SupervisionSegment +from lhotse.recipes.utils import read_manifests_if_cached + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + return parser.parse_args() + + +def normalize_text(utt: str) -> str: + utt = re.sub(r"[{0}]+".format("-"), " ", utt) + return re.sub(r"[^a-zA-Z\s]", "", utt).upper() + + +def preprocess_commonvoice( + language: str, + dataset: Optional[str] = None, +): + src_dir = Path(f"data/{language}/manifests") + output_dir = Path(f"data/{language}/fbank") + output_dir.mkdir(exist_ok=True) + + if dataset is None: + dataset_parts = ( + "dev", + "test", + "train", + ) + else: + dataset_parts = dataset.split(" ", -1) + + logging.info("Loading manifest") + prefix = f"cv-{language}" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, + prefix=prefix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + logging.info(f"Normalizing text in {partition}") + for sup in m["supervisions"]: + text = str(sup.text) + orig_text = text + sup.text = normalize_text(sup.text) + text = str(sup.text) + if len(orig_text) != len(text): + logging.info( + f"\nOriginal text vs normalized text:\n{orig_text}\n{text}" + ) + + # Create long-recording cut manifests. + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ).resample(16000) + + # Run data augmentation that needs to be done in the + # time domain. 
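A worked example for the `normalize_text` helper defined above (the sample transcript is made up):

```python
# Worked example for normalize_text: hyphens become spaces, everything
# except letters and whitespace is dropped, and the result is uppercased.
import re


def normalize_text(utt: str) -> str:
    utt = re.sub(r"[{0}]+".format("-"), " ", utt)
    return re.sub(r"[^a-zA-Z\s]", "", utt).upper()


print(normalize_text("Let's try some so-called 'zero-shot' tests, OK?"))
# -> LETS TRY SOME SO CALLED ZERO SHOT TESTS OK
```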
+ logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + preprocess_commonvoice( + language=args.language, + dataset=args.dataset, + ) + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/local/train_bpe_model.py b/egs/commonvoice/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/commonvoice/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/validate_bpe_lexicon.py b/egs/commonvoice/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/commonvoice/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh new file mode 100755 index 000000000..7a583f9c8 --- /dev/null +++ b/egs/commonvoice/ASR/prepare.sh @@ -0,0 +1,244 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=16 +stage=-1 +stop_stage=100 + +# Split data/${lang}set to this number of pieces +# This is to avoid OOM during feature extraction. +num_splits=1000 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/$release/$lang +# This directory contains the following files downloaded from +# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz +# +# - clips +# - dev.tsv +# - invalidated.tsv +# - other.tsv +# - reported.tsv +# - test.tsv +# - train.tsv +# - validated.tsv +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download +release=cv-corpus-13.0-2023-03-09 +lang=en + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/${lang}/lang_bpe_xxx, +# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data/${lang}". +# You can safely remove "data/${lang}" and rerun this script to regenerate it. +mkdir -p data/${lang} + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/$release, + # you can create a symlink + # + # ln -sfv /path/to/$release $dl_dir/$release + # + if [ ! -d $dl_dir/$release/$lang/clips ]; then + lhotse download commonvoice --languages $lang --release $release $dl_dir + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! 
-d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare CommonVoice manifest" + # We assume that you have downloaded the CommonVoice corpus + # to $dl_dir/$release + mkdir -p data/${lang}/manifests + if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then + lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests + touch data/${lang}/manifests/.cv-${lang}.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Preprocess CommonVoice manifest" + if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then + ./local/preprocess_commonvoice.py --language $lang + touch data/${lang}/fbank/.preprocess_complete + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for dev and test subsets of CommonVoice" + mkdir -p data/${lang}/fbank + if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then + ./local/compute_fbank_commonvoice_dev_test.py --language $lang + touch data/${lang}/fbank/.cv-${lang}_dev_test.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Split train subset into ${num_splits} pieces" + split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits} + if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then + lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir + touch $split_dir/.cv-${lang}_train_split.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Compute features for train subset of CommonVoice" + if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then + ./local/compute_fbank_commonvoice_splits.py \ + --num-workers $nj \ + --batch-duration 600 \ + --start 0 \ + --num-splits $num_splits \ + --language $lang + touch data/${lang}/fbank/.cv-${lang}_train.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Combine features for train" + if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then + pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz") + lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/${lang}/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + file=$( + find "data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz" + ) + gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt + + # Ensure space only appears once + sed -i 's/\t/ /g' $lang_dir/transcript_words.txt + sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt + fi + + if [ ! 
-f $lang_dir/words.txt ]; then + cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' > $lang_dir/words.txt + (echo '!SIL'; echo ''; echo ''; ) | + cat - $lang_dir/words.txt | sort | uniq | awk ' + BEGIN { + print " 0"; + } + { + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + printf("%s %d\n", $1, NR); + } + END { + printf("#0 %d\n", NR+1); + printf(" %d\n", NR+2); + printf(" %d\n", NR+3); + }' > $lang_dir/words || exit 1; + mv $lang_dir/words $lang_dir/words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py new file mode 100644 index 000000000..2c37244a4 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -0,0 +1,420 @@ +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class CommonVoiceAsrDataModule: + """ + DataModule for k2 ASR experiments. 
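Stage 9 above builds `words.txt` with an awk one-liner whose angle-bracket special symbols appear to have been stripped in this diff. The Python sketch below mirrors the intended construction; the exact symbol set (`<eps>`, `!SIL`, `<SPOKEN_NOISE>`, `<UNK>`, `#0`, `<s>`, `</s>`) follows comparable icefall recipes and should be treated as an assumption.

```python
# Python mirror of the Stage 9 words.txt construction (sketch only).
from pathlib import Path


def build_words_txt(transcript_path: Path, out_path: Path) -> None:
    words = set()
    for line in transcript_path.read_text(encoding="utf-8").splitlines():
        words.update(line.split())
    words |= {"!SIL", "<SPOKEN_NOISE>", "<UNK>"}  # assumed special words
    for bad in ("<s>", "</s>"):
        assert bad not in words, f"{bad} is in the vocabulary!"

    lines = ["<eps> 0"]
    for i, w in enumerate(sorted(words), start=1):
        lines.append(f"{w} {i}")
    n = len(words)
    # Disambiguation and sentence-boundary symbols go at the end.
    lines += [f"#0 {n + 1}", f"<s> {n + 2}", f"</s> {n + 3}"]
    out_path.write_text("\n".join(lines) + "\n", encoding="utf-8")


if __name__ == "__main__":
    # Hypothetical paths, mirroring $lang_dir in the shell script above.
    lang_dir = Path("data/en/lang_bpe_500")
    build_words_txt(lang_dir / "transcript_words.txt", lang_dir / "words.txt")
```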
+ It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. CommonVoice test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--language", + type=str, + default="en", + help="""Language of Common Voice""", + ) + group.add_argument( + "--cv-manifest-dir", + type=Path, + default=Path("data/en/fbank"), + help="Path to directory with CommonVoice train/dev/test cuts.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with the other cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. 
Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. 
+ # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + 
logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz" + ) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py new file mode 100755 index 000000000..52b2fbcab --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py @@ -0,0 +1,962 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(8) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 
35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ + +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/en/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. 
+ If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
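+
+    For instance, with `--decoding-method modified_beam_search --beam-size 4`
+    and a batch of two utterances, the returned dict would look like the
+    following (hypothetical transcripts):
+
+        {"beam_size_4": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}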
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + 
max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
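+
+    # Note: `len(dl)` may raise TypeError when the dataloader is built on a
+    # lazily-opened CutSet whose sampler cannot report its length in advance;
+    # in that case the batch count is shown as "?" in the progress log below.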
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif 
"beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + commonvoice = CommonVoiceAsrDataModule(args) + + dev_cuts = commonvoice.dev_cuts() + test_cuts = commonvoice.test_cuts() + + dev_dl = commonvoice.valid_dataloaders(dev_cuts) + test_dl = commonvoice.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py new file mode 100755 index 000000000..0c98885ac --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) + +""" +This 
script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
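+
+    Note: the model is traced with a dummy input of shape (1, 100, 80); because
+    "N" and "T" are declared as dynamic axes below, the exported graph accepts
+    arbitrary batch sizes and input lengths at inference time.
+
+    A minimal sketch of inspecting the attached meta data with onnxruntime
+    (assuming onnxruntime is installed; the file name matches the one produced
+    by this script):
+
+        import onnxruntime as ort
+
+        sess = ort.InferenceSession("encoder-epoch-9999-avg-1.onnx")
+        print(sess.get_modelmeta().custom_metadata_map)  # model_type, version, ...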
+ """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless7", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + 
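+            # Note: the average is taken over the range (start, epoch]; the
+            # checkpoint at `start` serves as the reference point and is
+            # excluded, as stated in the log message below.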
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py new file mode 100755 index 000000000..53705321e --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) +# +# See 
../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 5 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 5 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/commonvoice/ASR + ./pruned_transducer_stateless7/decode.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/en/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 + # You will find the pre-trained model in icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, 
--avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py new file mode 100755 index 000000000..19c518eaf --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./pruned_transducer_stateless7/export.py \ + --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - cpu_jit.pt + +3. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +4. 
Run this file + +./pruned_transducer_stateless7/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx +""" + +import argparse +import logging + +from icefall import is_module_available +from onnx_pretrained import OnnxModel + +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + return parser + + +def test_encoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + C = 80 + for i in range(3): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") + + x = torch.rand(N, T, C) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) + x_lens[0] = T + + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) + + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) + + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) 
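+
+    # The check itself is simple: load the torchscript model and the three ONNX
+    # models, drive them with identical random inputs, and assert that their
+    # outputs agree within the tolerances used in the test_* functions above.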
+ + torch_model = torch.jit.load(args.jit_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + test_encoder(torch_model, onnx_model) + + logging.info("Test decoder") + test_decoder(torch_model, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_model, onnx_model) + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py new file mode 100755 index 000000000..eee19191e --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -0,0 +1,419 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +We use the pre-trained model from +https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/en/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +3. 
Run this file + +./pruned_transducer_stateless7/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ + --tokens $repo/data/en/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +""" + +import argparse +import logging +import math +from typing import List, Tuple + +import k2 +import kaldifeat +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, joiner_dim) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. 
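+ Note: the decoder output is computed once for the initial blank context and
+ then re-computed only for frames in which at least one stream emits a
+ non-blank token (see the `emitted` flag in the loop below).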
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.run_decoder(decoder_input) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + # current_encoder_out's shape: (batch_size, joiner_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + logits = model.run_joiner(current_encoder_out, decoder_out) + + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = model.run_decoder(decoder_input) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += symbol_table[i] + return text.replace("▁", " ").strip() + + for filename, hyp in zip(args.sound_files, hyps): + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git 
a/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py new file mode 100755 index 000000000..a22d1b4ba --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/en/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 5 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/en/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/en/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/en/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/en/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. 
+ +Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by +./pruned_transducer_stateless7/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in 
zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py new file mode 100755 index 000000000..73a29a90a --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -0,0 +1,1250 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
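+# Note: this CommonVoice recipe reuses the librispeech
+# pruned_transducer_stateless7 model code (zipformer.py, optim.py, scaling.py
+# and scaling_converter.py are symlinks); only the data module
+# (asr_datamodule.py) and defaults such as data/en/lang_bpe_500 differ.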
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CommonVoiceAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/en/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. 
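+ # For example, with --max-duration 300 and the 10 ms frame shift used here,
+ # max_frames = 300 * 1000 // 10 = 30000 and allowed_max_frames = 33000.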
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
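+ # Concretely, the scale is doubled every 100 batches while it is below 1.0
+ # and every 400 batches while it is below 8.0; a warning is logged once it
+ # drops below 0.01, and training is aborted if it ever falls under 1e-5.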
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + commonvoice = CommonVoiceAsrDataModule(args) + + train_cuts = commonvoice.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = commonvoice.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = commonvoice.dev_cuts() + valid_dl = commonvoice.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CommonVoiceAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/shared b/egs/commonvoice/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/commonvoice/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore index 5d965832e..cd0e20c4c 100644 --- a/egs/csj/ASR/.gitignore +++ b/egs/csj/ASR/.gitignore @@ -5,4 +5,4 @@ notify_tg.py finetune_* misc.ini .vscode/* -offline/* \ No newline at end of file +offline/* diff --git a/egs/csj/ASR/README.md b/egs/csj/ASR/README.md new file mode 100644 index 000000000..95c2ec6ac --- /dev/null +++ b/egs/csj/ASR/README.md @@ -0,0 +1,11 @@ +# Introduction + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +These are the types of architectures currently available. + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|---------------------------------------------------| +| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming | diff --git a/egs/csj/ASR/RESULTS.md b/egs/csj/ASR/RESULTS.md new file mode 100644 index 000000000..56fdb899f --- /dev/null +++ b/egs/csj/ASR/RESULTS.md @@ -0,0 +1,200 @@ +# Results + +## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) + +### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +Number of model parameters: 75688409, i.e. 75.7M. 
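For reference, this figure is simply the sum of `p.numel()` over `model.parameters()` that the training script logs; a minimal sketch with a toy module (building the full streaming Zipformer is out of scope here):

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    # Same expression the train.py scripts use for "Number of model parameters".
    return sum(p.numel() for p in model.parameters())


# Toy module for illustration; the streaming Zipformer above reports 75688409.
print(count_parameters(nn.Linear(384, 512)))  # 197120 = 512 * 384 + 512
```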
+ +#### training on disfluent transcript + +The CERs are: + +| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode | +| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- | +| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.4 | 5.02 | --epoch 30 --avg 17 | simulated streaming | +| fast beam search | 320ms | 5.34 | 4.1 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise | +| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming | +| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.7 | 4.98 | --epoch 30 --avg 17 | chunk-wise | +| modified beam search | 320ms | 5.2 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming | +| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise | +| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.6 | --epoch 30 --avg 17 | simulated streaming | +| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise | +| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming | +| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise | +| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming | +| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise | + +Note: `simulated streaming` indicates feeding the full utterance during decoding using `decode.py`, +while `chunk-wise` indicates feeding a certain number of frames at a time using `streaming_decode.py`. + +The training command was: +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --max-duration 375 \ + --transcript-mode disfluent \ + --lang data/lang_char \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --pad-feature 30 \ + --musan-dir /mnt/host/corpus/musan/musan/fbank +``` + +The simulated streaming decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --epoch 30 \ + --avg 17 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode disfluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --pad-feature 30 \ + --gpu 0 + done +done +``` + +The streaming chunk-wise decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/streaming_decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --epoch 30 \ + --avg 17 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode disfluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \ +
--decode-chunk-len $chunk \ + --gpu 2 \ + --num-decode-streams 40 + done +done +``` + +#### training on fluent transcript + +The CERs are: + +| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode | +| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- | +| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming | +| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise | +| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming | +| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise | +| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming | +| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise | +| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming | +| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise | +| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming | +| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise | +| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming | +| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise | + +Note: `simulated streaming` indicates feeding the full utterance during decoding using `decode.py`, +while `chunk-wise` indicates feeding a certain number of frames at a time using `streaming_decode.py`.
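+
+For illustration only, the difference between the two modes can be sketched with a toy, self-contained example. This is not the actual icefall code: `toy_encoder` merely stands in for the streaming Zipformer encoder, and only the feeding pattern and state carry-over are the point.
+
+```python
+import torch
+
+
+def toy_encoder(frames: torch.Tensor, state: torch.Tensor):
+    # Stand-in "encoder": a running sum carried across chunks via `state`.
+    out = frames.cumsum(dim=0) + state
+    return out, out[-1:]
+
+
+feats = torch.randn(1000, 80)  # fbank features of one utterance
+decode_chunk_len = 32          # frames fed per step, cf. --decode-chunk-len
+
+# Simulated streaming (decode.py): the caller feeds the full utterance in one
+# call; any chunking happens inside the model, e.g. via attention masking.
+full_out, _ = toy_encoder(feats, torch.zeros(1, 80))
+
+# Chunk-wise streaming (streaming_decode.py): the caller feeds
+# decode_chunk_len frames at a time and carries the encoder state over.
+state = torch.zeros(1, 80)
+outs = []
+for start in range(0, feats.size(0), decode_chunk_len):
+    chunk_out, state = toy_encoder(feats[start : start + decode_chunk_len], state)
+    outs.append(chunk_out)
+stream_out = torch.cat(outs)
+
+assert torch.allclose(full_out, stream_out, atol=1e-4)  # identical for this toy encoder
+```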
+ +The training command was: +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --max-duration 375 \ + --transcript-mode fluent \ + --lang data/lang_char \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --pad-feature 30 \ + --musan-dir /mnt/host/corpus/musan/musan/fbank +``` + +The simulated streaming decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --epoch 30 \ + --avg 12 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode fluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --pad-feature 30 \ + --gpu 1 + done +done +``` + +The streaming chunk-wise decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/streaming_decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --epoch 30 \ + --avg 12 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode fluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --gpu 3 \ + --num-decode-streams 40 + done +done +``` + +#### Comparing disfluent to fluent + +$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$ + +This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared. 
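+
+In terms of the counts tallied by `disfluent_recogs_to_fluent.py`, the numerator collects the disfluent model's errors on content characters (substitutions and deletions on characters without CSJ `D`/`F` tags, plus all insertions), and the denominator is the number of content characters in the fluent reference. A toy computation with invented counts:
+
+```python
+# Invented counts, only to show how the metric is assembled.
+sub_f, ins, del_f = 120, 40, 60  # content-character substitutions/deletions + all insertions
+n_f = 6000                       # content characters in the fluent reference
+cer_d_on_f = (sub_f + ins + del_f) / n_f
+print(f"{cer_d_on_f:.2%}")       # 3.67%
+```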
+ +| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode | +| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- | +| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming | +| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise | +| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming | +| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise | +| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming | +| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise | +| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming | +| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise | +| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming | +| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise | +| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming | +| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise | +| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | | diff --git a/egs/csj/ASR/local/add_transcript_mode.py b/egs/csj/ASR/local/add_transcript_mode.py new file mode 100644 index 000000000..f6b4b2caf --- /dev/null +++ b/egs/csj/ASR/local/add_transcript_mode.py @@ -0,0 +1,94 @@ +import argparse +import logging +from configparser import ConfigParser +from pathlib import Path +from typing import List + +from lhotse import CutSet, SupervisionSet +from lhotse.recipes.csj import CSJSDBParser + +ARGPARSE_DESCRIPTION = """ +This script adds transcript modes to an existing CutSet or SupervisionSet. +""" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=ARGPARSE_DESCRIPTION, + ) + parser.add_argument( + "-f", + "--fbank-dir", + type=Path, + help="Path to directory where manifests are stored.", + ) + parser.add_argument( + "-c", + "--config", + type=Path, + nargs="+", + help="Path to config file for transcript parsing.", + ) + return parser.parse_args() + + +def get_CSJParsers(config_files: List[Path]) -> List[CSJSDBParser]: + parsers = [] + for config_file in config_files: + config = ConfigParser() + config.optionxform = str + assert config.read(config_file), f"{config_file} could not be found." 
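+        # Values in [DECISIONS] are per-tag parser options: numeric strings become
+        # int flags (e.g. 0/1), anything else is kept as a literal string.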
+ decisions = {} + for k, v in config["DECISIONS"].items(): + try: + decisions[k] = int(v) + except ValueError: + decisions[k] = v + parsers.append( + (config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions)) + ) + return parsers + + +def main(): + args = get_args() + logging.basicConfig( + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), + level=logging.INFO, + ) + parsers = get_CSJParsers(args.config) + config = ConfigParser() + config.optionxform = str + assert config.read(args.config), args.config + decisions = {} + for k, v in config["DECISIONS"].items(): + try: + decisions[k] = int(v) + except ValueError: + decisions[k] = v + + logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode.") + + manifests = args.fbank_dir.glob("csj_cuts_*.jsonl.gz") + assert manifests, f"No cuts to be found in {args.fbank_dir}" + + for manifest in manifests: + results = [] + logging.info(f"Adding transcript modes to {manifest.name} now.") + cutset = CutSet.from_file(manifest) + for cut in cutset: + for name, parser in parsers: + cut.supervisions[0].custom[name] = parser.parse( + cut.supervisions[0].custom["raw"] + ) + cut.supervisions[0].text = "" + results.append(cut) + results = CutSet.from_items(results) + res_file = manifest.as_posix() + manifest.replace(manifest.parent / ("bak." + manifest.name)) + results.to_file(res_file) + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py index 994dedbdd..ce560025d 100644 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,36 +19,23 @@ import argparse import logging import os -from itertools import islice from pathlib import Path -from random import Random from typing import List, Tuple import torch -from lhotse import ( + +# fmt: off +from lhotse import ( # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527 CutSet, Fbank, FbankConfig, - # fmt: off - # See the following for why LilcomChunkyWriter is preferred - # https://github.com/k2-fsa/icefall/pull/404 - # https://github.com/lhotse-speech/lhotse/pull/527 - # fmt: on LilcomChunkyWriter, RecordingSet, SupervisionSet, ) +from lhotse.recipes.csj import concat_csj_supervisions -ARGPARSE_DESCRIPTION = """ -This script follows the espnet method of splitting the remaining core+noncore -utterances into valid and train cutsets at an index which is by default 4000. - -In other words, the core+noncore utterances are shuffled, where 4000 utterances -of the shuffled set go to the `valid` cutset and are not subject to speed -perturbation. The remaining utterances become the `train` cutset and are speed- -perturbed (0.9x, 1.0x, 1.1x). - -""" +# fmt: on # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
@@ -58,78 +45,100 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) RNG_SEED = 42 +# concat_params_train = [ +# {"gap": 1.0, "maxlen": 10.0}, +# {"gap": 1.5, "maxlen": 8.0}, +# {"gap": 1.0, "maxlen": 18.0}, +# ] + +concat_params = {"gap": 1.0, "maxlen": 10.0} def make_cutset_blueprints( manifest_dir: Path, - split: int, ) -> List[Tuple[str, CutSet]]: cut_sets = [] + logging.info("Creating non-train cuts.") + # Create eval datasets - logging.info("Creating eval cuts.") for i in range(1, 4): + sps = sorted( + SupervisionSet.from_file( + manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" + ), + key=lambda x: x.id, + ) + cut_set = CutSet.from_manifests( recordings=RecordingSet.from_file( manifest_dir / f"csj_recordings_eval{i}.jsonl.gz" ), - supervisions=SupervisionSet.from_file( - manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" - ), + supervisions=concat_csj_supervisions(sps, **concat_params), ) cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) cut_sets.append((f"eval{i}", cut_set)) - # Create train and valid cuts - logging.info( - "Loading, trimming, and shuffling the remaining core+noncore cuts." + # Create excluded dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"), + key=lambda x: x.id, ) - recording_set = RecordingSet.from_file( - manifest_dir / "csj_recordings_core.jsonl.gz" - ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") - supervision_set = SupervisionSet.from_file( - manifest_dir / "csj_supervisions_core.jsonl.gz" - ) + SupervisionSet.from_file( - manifest_dir / "csj_supervisions_noncore.jsonl.gz" - ) - cut_set = CutSet.from_manifests( - recordings=recording_set, - supervisions=supervision_set, + recordings=RecordingSet.from_file( + manifest_dir / "csj_recordings_excluded.jsonl.gz" + ), + supervisions=concat_csj_supervisions(sps, **concat_params), ) cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_set = cut_set.shuffle(Random(RNG_SEED)) + cut_sets.append(("excluded", cut_set)) - logging.info( - "Creating valid and train cuts from core and noncore," - f"split at {split}." 
+ # Create valid dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"), + key=lambda x: x.id, ) - valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) + cut_set = CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "csj_recordings_valid.jsonl.gz" + ), + supervisions=concat_csj_supervisions(sps, **concat_params), + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_sets.append(("valid", cut_set)) - train_set = CutSet.from_cuts(islice(cut_set, split, None)) - train_set = ( - train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) + logging.info("Creating train cuts.") + + # Create train dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz") + + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"), + key=lambda x: x.id, ) - cut_sets.extend([("valid", valid_set), ("train", train_set)]) + recording = RecordingSet.from_file( + manifest_dir / "csj_recordings_core.jsonl.gz" + ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") + + train_set = CutSet.from_manifests( + recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params) + ).trim_to_supervisions(keep_overlapping=False) + train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) + + cut_sets.append(("train", train_set)) return cut_sets def get_args(): parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to save manifests" + "-m", "--manifest-dir", type=Path, help="Path to save manifests" ) parser.add_argument( - "--fbank-dir", type=Path, help="Path to save fbank features" - ) - parser.add_argument( - "--split", type=int, default=4000, help="Split at this index" + "-f", "--fbank-dir", type=Path, help="Path to save fbank features" ) return parser.parse_args() @@ -141,9 +150,7 @@ def main(): extractor = Fbank(FbankConfig(num_mel_bins=80)) num_jobs = min(16, os.cpu_count()) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) @@ -154,7 +161,7 @@ def main(): ) return else: - cut_sets = make_cutset_blueprints(args.manifest_dir, args.split) + cut_sets = make_cutset_blueprints(args.manifest_dir) for part, cut_set in cut_sets: logging.info(f"Processing {part}") cut_set = cut_set.compute_and_store_features( @@ -163,7 +170,7 @@ def main(): storage_path=(args.fbank_dir / f"feats_{part}").as_posix(), storage_type=LilcomChunkyWriter, ) - cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz") + cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz") logging.info("All fbank computed for CSJ.") (args.fbank_dir / ".done").touch() diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py index 44a33c4eb..c942df98e 100644 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -26,12 +26,9 @@ from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor - ARGPARSE_DESCRIPTION = """ This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. -The generated fbank features are saved in data/fbank. 
""" # Torch's multithreaded behavior needs to be disabled or @@ -43,8 +40,6 @@ torch.set_num_interop_threads(1) def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): - # src_dir = Path("data/manifests") - # output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) num_mel_bins = 80 @@ -84,9 +79,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -108,10 +101,10 @@ def get_args(): ) parser.add_argument( - "--manifest-dir", type=Path, help="Path to save manifests" + "-m", "--manifest-dir", type=Path, help="Path to save manifests" ) parser.add_argument( - "--fbank-dir", type=Path, help="Path to save fbank features" + "-f", "--fbank-dir", type=Path, help="Path to save fbank features" ) return parser.parse_args() @@ -119,9 +112,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan(args.manifest_dir, args.fbank_dir) diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini index eb70673de..4f0a9ec0e 100644 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -1,321 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = disfluent -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 0 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 0 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 0 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 0 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 0 -; # Example: '(X (D2 ノ))' -D2^ = 0 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 
説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 0 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index 5d22f9eb8..5d033ed17 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,321 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = fluent -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 1 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 1 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 1 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 1 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 1 -; # Example: '(X (D2 ノ))' -D2^ = 1 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? 
すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 0 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index 2613c3409..3ada9aa24 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,321 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = number -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 1 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 1 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 1 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 1 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 1 -; # Example: '(X (D2 ノ))' -D2^ = 1 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? 
すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) A_num = 1 -A_num^ = 1 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo - diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index 8ba451dd5..dafd65c9a 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,322 +1,80 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode -; # See https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf +; # From https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf MODE = symbol -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . 
-; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' -F = # -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = # +F = "#", ["F"] ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' -D = @ -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = @ +D = "@", ["D"] ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' -D2 = @ -; # Example: '(X (D2 ノ))' -D2^ = @ +D2 = "@", ["D2"] ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 1 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? 
セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo - diff --git a/egs/csj/ASR/local/disfluent_recogs_to_fluent.py b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py new file mode 100644 index 000000000..45c9c7656 --- /dev/null +++ b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py @@ -0,0 +1,202 @@ +import argparse +from pathlib import Path + +import kaldialign +from lhotse import CutSet + +ARGPARSE_DESCRIPTION = """ +This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript, +compares it against a fluent transcript, and saves the results in a separate directory. +This is useful to compare disfluent models with fluent models on the same metric. + +""" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=ARGPARSE_DESCRIPTION, + ) + parser.add_argument( + "--recogs", + type=Path, + required=True, + help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.", + ) + parser.add_argument( + "--cut", + type=Path, + required=True, + help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.", + ) + parser.add_argument( + "--res-dir", type=Path, required=True, help="Path to save results" + ) + return parser.parse_args() + + +def d2f(stats): + """ + Compare the outputs of a disfluent model against a fluent reference. + Indicates a disfluent model's performance only on the content words + + CER^d_f = (sub_f + ins + del_f) / Nf + + """ + return stats["base"] / stats["Nf"] + + +def calc_cer(refs, hyps): + subs = { + "F": 0, + "D": 0, + } + ins = 0 + dels = { + "F": 0, + "D": 0, + } + cors = { + "F": 0, + "D": 0, + } + dis_ref_len = 0 + flu_ref_len = 0 + + for ref, hyp in zip(refs, hyps): + assert ( + ref[0] == hyp[0] + ), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}." 
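+        # Align ref and hyp with kaldialign below, then tally errors separately for
+        # characters carrying CSJ disfluency tags (D/F) and for content characters.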
+ tag = ref[2].copy() + ref = ref[1] + dis_ref_len += len(ref) + # Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively. + flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)]) + hyp = hyp[1] + ali = kaldialign.align(ref, hyp, "*") + tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali] + for tag, (ref_word, hyp_word) in zip(tags, ali): + if "D" in tag or "F" in tag: + tag = "D" + else: + tag = "F" + + if ref_word == "*": + ins += 1 + elif hyp_word == "*": + dels[tag] += 1 + elif ref_word != hyp_word: + subs[tag] += 1 + else: + cors[tag] += 1 + + return { + "subs": subs, + "ins": ins, + "dels": dels, + "cors": cors, + "dis_ref_len": dis_ref_len, + "flu_ref_len": flu_ref_len, + } + + +def for_each_recogs(recogs_file: Path, refs, out_dir): + hyps = [] + with recogs_file.open() as fin: + for line in fin: + if "ref" in line: + continue + cutid, hyp = line.split(":\thyp=") + hyps.append((cutid, eval(hyp))) + + assert len(refs) == len( + hyps + ), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal." + stats = calc_cer(refs, hyps) + stat_table = ["tag,yes,no"] + + for cer_type in ["subs", "dels", "cors", "ins"]: + ret = f"{cer_type}" + for df in ["D", "F"]: + try: + ret += f",{stats[cer_type][df]}" + except TypeError: + # insertions do not belong to F or D, and is not subscriptable. + ret += f",{stats[cer_type]}," + break + stat_table.append(ret) + stat_table = "\n".join(stat_table) + + stats = { + "subd": stats["subs"]["D"], + "deld": stats["dels"]["D"], + "cord": stats["cors"]["D"], + "Nf": stats["flu_ref_len"], + "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"], + } + + cer = d2f(stats) + results = [ + f"{cer:.2%}", + f"Nf,{stats['Nf']}", + ] + results = "\n".join(results) + + with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout: + fout.write(results) + fout.write("\n\n") + fout.write(stat_table) + + +def main(): + args = get_args() + recogs_file: Path = args.recogs + assert ( + recogs_file.is_file() or recogs_file.is_dir() + ), f"recogs_file cannot be found at {recogs_file}." + + args.res_dir.mkdir(parents=True, exist_ok=True) + + if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"): + assert ( + "csj_cuts" in args.cut.name + ), f"Expected {args.cut} to be a cuts manifest." 
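+        # A single recogs file: build (cut id, disfluent reference, per-character
+        # CSJ tag) triples from the cuts manifest and score just this file.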
+ + refs: CutSet = CutSet.from_file(args.cut) + refs = sorted( + [ + ( + e.id, + list(e.supervisions[0].custom["disfluent"]), + e.supervisions[0].custom["disfluent_tag"].split(","), + ) + for e in refs + ], + key=lambda x: x[0], + ) + for_each_recogs(recogs_file, refs, args.res_dir) + + elif recogs_file.is_dir(): + recogs_file_path = recogs_file + for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]: + refs: CutSet = CutSet.from_file(args.cut / f"csj_cuts_{partname}.jsonl.gz") + refs = sorted( + [ + ( + r.id, + list(r.supervisions[0].custom["disfluent"]), + r.supervisions[0].custom["disfluent_tag"].split(","), + ) + for r in refs + ], + key=lambda x: x[0], + ) + for recogs_file in recogs_file_path.glob(f"recogs-{partname}-*.txt"): + for_each_recogs(recogs_file, refs, args.res_dir) + + else: + raise TypeError(f"Unrecognised recogs file provided: {recogs_file}") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c9de21073..924474d33 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,9 +37,7 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument( - "--manifest-dir", type=Path, help="Path to cutset manifests" - ) + parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") return parser.parse_args() @@ -47,8 +45,8 @@ def get_parser(): def main(): args = get_parser() - for path in args.manifest_dir.glob("csj_cuts_*.jsonl.gz"): - + for part in ["eval1", "eval2", "eval3", "valid", "excluded", "train"]: + path = args.manifest_dir / f"csj_cuts_{part}.jsonl.gz" cuts: CutSet = load_manifest(path) print("\n---------------------------------\n") @@ -60,123 +58,271 @@ if __name__ == "__main__": main() """ -## eval1 -Cuts count: 1272 -Total duration (hh:mm:ss): 01:50:07 -Speech duration (hh:mm:ss): 01:50:07 (100.0%) -Duration statistics (seconds): -mean 5.2 -std 3.9 -min 0.2 -25% 1.9 -50% 4.0 -75% 8.1 -99% 14.3 -99.5% 14.7 -99.9% 16.0 -max 16.9 -Recordings available: 1272 -Features available: 1272 -Supervisions available: 1272 +csj_cuts_eval1.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1023 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:55:40 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.8 │ +├───────────────────────────┼──────────┤ +│ std │ 2.7 │ +├───────────────────────────┼──────────┤ +│ min │ 0.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.7 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1023 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1023 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- fluent (in 1272 cuts) -- disfluent (in 1272 cuts) -- number (in 1272 cuts) -- symbol (in 1272 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:55:40 │ 100.00% of recording │ 
+├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:55:40 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ -## eval2 -Cuts count: 1292 -Total duration (hh:mm:ss): 01:56:50 -Speech duration (hh:mm:ss): 01:56:50 (100.0%) -Duration statistics (seconds): -mean 5.4 -std 3.9 -min 0.1 -25% 2.1 -50% 4.6 -75% 8.6 -99% 14.1 -99.5% 15.2 -99.9% 16.1 -max 16.9 -Recordings available: 1292 -Features available: 1292 -Supervisions available: 1292 -SUPERVISION custom fields: -- fluent (in 1292 cuts) -- number (in 1292 cuts) -- symbol (in 1292 cuts) -- disfluent (in 1292 cuts) +--------------------------------- -## eval3 -Cuts count: 1385 -Total duration (hh:mm:ss): 01:19:21 -Speech duration (hh:mm:ss): 01:19:21 (100.0%) -Duration statistics (seconds): -mean 3.4 -std 3.0 -min 0.2 -25% 1.2 -50% 2.5 -75% 4.6 -99% 12.7 -99.5% 13.7 -99.9% 15.0 -max 15.9 -Recordings available: 1385 -Features available: 1385 -Supervisions available: 1385 +csj_cuts_eval2.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1025 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:02:07 │ +├───────────────────────────┼──────────┤ +│ mean │ 7.1 │ +├───────────────────────────┼──────────┤ +│ std │ 2.5 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 5.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1025 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1025 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- number (in 1385 cuts) -- symbol (in 1385 cuts) -- fluent (in 1385 cuts) -- disfluent (in 1385 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:02:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:02:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ -## valid -Cuts count: 4000 -Total duration (hh:mm:ss): 05:08:09 -Speech duration (hh:mm:ss): 05:08:09 (100.0%) -Duration statistics (seconds): -mean 4.6 -std 3.8 -min 0.1 -25% 1.5 -50% 3.4 -75% 7.0 -99% 13.8 -99.5% 14.8 -99.9% 16.0 -max 17.3 -Recordings available: 4000 -Features available: 4000 -Supervisions available: 4000 -SUPERVISION custom fields: -- fluent (in 4000 cuts) -- symbol (in 4000 cuts) -- disfluent (in 4000 cuts) -- number (in 4000 cuts) +--------------------------------- -## train -Cuts count: 1291134 -Total duration (hh:mm:ss): 1596:37:27 -Speech duration (hh:mm:ss): 1596:37:27 (100.0%) -Duration statistics (seconds): -mean 4.5 -std 3.6 -min 0.0 -25% 1.6 -50% 3.3 -75% 6.4 -99% 14.0 -99.5% 14.8 -99.9% 16.6 -max 27.8 
-Recordings available: 1291134 -Features available: 1291134 -Supervisions available: 1291134 +csj_cuts_eval3.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 865 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:26:44 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.3 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.3 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.8 │ +├───────────────────────────┼──────────┤ +│ 75% │ 8.7 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 865 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 865 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- disfluent (in 1291134 cuts) -- fluent (in 1291134 cuts) -- symbol (in 1291134 cuts) -- number (in 1291134 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:26:44 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:26:44 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_valid.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 3743 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 06:40:15 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.4 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.1 │ +├───────────────────────────┼──────────┤ +│ max │ 11.8 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 3743 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 3743 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 06:40:15 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 06:40:15 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_excluded.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 980 │ +├───────────────────────────┼──────────┤ +│ 
Total duration (hh:mm:ss) │ 00:56:06 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.1 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 0.8 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.2 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 980 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 980 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 00:56:06 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 00:56:06 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_train.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤════════════╕ +│ Cuts count: │ 914151 │ +├───────────────────────────┼────────────┤ +│ Total duration (hh:mm:ss) │ 1695:29:43 │ +├───────────────────────────┼────────────┤ +│ mean │ 6.7 │ +├───────────────────────────┼────────────┤ +│ std │ 2.9 │ +├───────────────────────────┼────────────┤ +│ min │ 0.1 │ +├───────────────────────────┼────────────┤ +│ 25% │ 4.6 │ +├───────────────────────────┼────────────┤ +│ 50% │ 7.5 │ +├───────────────────────────┼────────────┤ +│ 75% │ 8.9 │ +├───────────────────────────┼────────────┤ +│ 99% │ 11.0 │ +├───────────────────────────┼────────────┤ +│ 99.5% │ 11.0 │ +├───────────────────────────┼────────────┤ +│ 99.9% │ 11.1 │ +├───────────────────────────┼────────────┤ +│ max │ 18.0 │ +├───────────────────────────┼────────────┤ +│ Recordings available: │ 914151 │ +├───────────────────────────┼────────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼────────────┤ +│ Supervisions available: │ 914151 │ +╘═══════════════════════════╧════════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤════════════╤══════════════════════╕ +│ Total speech duration │ 1695:29:43 │ 100.00% of recording │ +├──────────────────────────────┼────────────┼──────────────────────┤ +│ Total speaking time duration │ 1695:29:43 │ 100.00% of recording │ +├──────────────────────────────┼────────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧════════════╧══════════════════════╛ """ diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index e4d996871..58b197922 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -21,24 +21,14 @@ import logging from pathlib import Path from lhotse import CutSet +from lhotse.recipes.csj import CSJSDBParser ARGPARSE_DESCRIPTION = """ -This script gathers all training transcripts of the specified {trans_mode} type -and produces a token_list that would be output 
set of the ASR system. +This script gathers all training transcripts, parses them in disfluent mode, and produces a token list that would be the output set of the ASR system. -It splits transcripts by whitespace into lists, then, for each word in the -list, if the word does not appear in the list of user-defined multicharacter -strings, it further splits that word into individual characters to be counted -into the output token set. - -It outputs 4 files into the lang directory: -- trans_mode: the name of transcript mode. If trans_mode was not specified, - this will be an empty file. -- userdef_string: a list of user defined strings that should not be split - further into individual characters. By default, it contains "", "", - "" -- words_len: the total number of tokens in the output set. -- words.txt: a list of tokens in the output set. The length matches words_len. +It outputs 3 files into the lang directory: +- tokens.txt: a list of tokens in the output set. +- lang_type: a file that contains the string "char" """ @@ -50,104 +40,52 @@ def get_args(): ) parser.add_argument( - "--train-cut", type=Path, required=True, help="Path to the train cut" - ) - - parser.add_argument( - "--trans-mode", - type=str, - default=None, - help=( - "Name of the transcript mode to use. " - "If lang-dir is not set, this will also name the lang-dir" - ), + "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" ) parser.add_argument( "--lang-dir", type=Path, - default=None, + default=Path("data/lang_char"), help=( "Name of lang dir. " "If not set, this will default to lang_char_{trans-mode}" ), ) - parser.add_argument( - "--userdef-string", - type=Path, - default=None, - help="Multicharacter strings that do not need to be split", - ) - return parser.parse_args() def main(): args = get_args() - logging.basicConfig( - format=( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" - ), + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), level=logging.INFO, ) - if not args.lang_dir: - p = "lang_char" - if args.trans_mode: - p += f"_{args.trans_mode}" - args.lang_dir = Path(p) + sysdef_string = set(["", "", ""]) - if args.userdef_string: - args.userdef_string = set(args.userdef_string.read_text().split()) - else: - args.userdef_string = set() + # Using disfluent parsing as fluent is a subset of disfluent + parser = CSJSDBParser() - sysdef_string = ["", "", ""] - args.userdef_string.update(sysdef_string) + token_set = set() + logging.info(f"Creating vocabulary from {args.train_cut}.") + train_cut: CutSet = CutSet.from_file(args.train_cut) + for cut in train_cut: + if "_sp" in cut.id: + continue - train_set: CutSet = CutSet.from_file(args.train_cut) - - words = set() - logging.info( - f"Creating vocabulary from {args.train_cut.name}" - f" at {args.trans_mode} mode." 
- ) - for cut in train_set: - try: - text: str = ( - cut.supervisions[0].custom[args.trans_mode] - if args.trans_mode - else cut.supervisions[0].text - ) - except KeyError: - raise KeyError( - f"Could not find {args.trans_mode} in " - f"{cut.supervisions[0].custom}" - ) - for t in text.split(): - if t in args.userdef_string: - words.add(t) - else: - words.update(c for c in list(t)) - - words -= set(sysdef_string) - words = sorted(words) - words = [""] + words + ["", ""] + text: str = cut.supervisions[0].custom["raw"] + for w in parser.parse(text, sep=" ").split(" "): + token_set.update(w) + token_set = [""] + sorted(token_set - sysdef_string) + ["", ""] args.lang_dir.mkdir(parents=True, exist_ok=True) - (args.lang_dir / "words.txt").write_text( - "\n".join(f"{word}\t{i}" for i, word in enumerate(words)) + (args.lang_dir / "tokens.txt").write_text( + "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) ) - (args.lang_dir / "words_len").write_text(f"{len(words)}") - - (args.lang_dir / "userdef_string").write_text( - "\n".join(args.userdef_string) - ) - - (args.lang_dir / "trans_mode").write_text(args.trans_mode) + (args.lang_dir / "lang_type").write_text("char") logging.info("Done.") diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py new file mode 100644 index 000000000..619820a75 --- /dev/null +++ b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -0,0 +1,462 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
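# Usage sketch (illustrative only, not part of the recipe; argument names follow
# add_arguments() defined below). A train or decode script typically drives this
# datamodule roughly as:
#
#   parser = argparse.ArgumentParser()
#   CSJAsrDataModule.add_arguments(parser)
#   args = parser.parse_args()
#   csj = CSJAsrDataModule(args)
#   train_dl = csj.train_dataloaders(csj.train_cuts())
#   valid_dl = csj.valid_dataloaders(csj.valid_cuts())
#   eval1_dl = csj.test_dataloaders(csj.eval1_cuts())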
+ + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrVariableTranscriptDataset(K2SpeechRecognitionDataset): + def __init__( + self, + *args, + transcript_mode: str = "", + return_cuts: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.transcript_mode = transcript_mode + self.return_cuts = True + self._return_cuts = return_cuts + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + batch = super().__getitem__(cuts) + + if self.transcript_mode: + batch["supervisions"]["text"] = [ + supervision.custom[self.transcript_mode] + for cut in batch["supervisions"]["cut"] + for supervision in cut.supervisions + ] + + if not self._return_cuts: + del batch["supervisions"]["cut"] + + return batch + + +class CSJAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--transcript-mode", + type=str, + default="", + help="Mode of transcript in supervision to use.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--musan-dir", type=Path, help="Path to directory with musan cuts. " + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. 
mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = AsrVariableTranscriptDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = AsrVariableTranscriptDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
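 + # Each dataloader worker then re-seeds itself through _SeedWorkers, i.e.
 + # fix_random_seed(seed + worker_id), so workers draw different but
 + # reproducible augmentation randomness.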
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = AsrVariableTranscriptDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + else: + validate = AsrVariableTranscriptDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + + test = AsrVariableTranscriptDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_train.jsonl.gz") + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get valid cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_valid.jsonl.gz") + + @lru_cache() + def excluded_cuts(self) -> CutSet: + logging.info("About to get excluded cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_excluded.jsonl.gz") + + @lru_cache() + def eval1_cuts(self) -> CutSet: + logging.info("About to get eval1 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval1.jsonl.gz") + + @lru_cache() + def eval2_cuts(self) -> CutSet: + logging.info("About to get eval2 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval2.jsonl.gz") + + @lru_cache() + def eval3_cuts(self) -> CutSet: + logging.info("About to get eval3 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval3.jsonl.gz") diff --git a/egs/csj/ASR/local/utils/tokenizer.py b/egs/csj/ASR/local/utils/tokenizer.py new file mode 100644 index 000000000..c9be72be1 --- /dev/null +++ b/egs/csj/ASR/local/utils/tokenizer.py @@ -0,0 +1,253 @@ +import argparse +from pathlib import Path +from typing import Callable, List, Union + +import sentencepiece as spm +from k2 import SymbolTable + + +class Tokenizer: + text2word: Callable[[str], List[str]] + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser): + group = 
parser.add_argument_group(title="Lang related options") + + group.add_argument("--lang", type=Path, help="Path to lang directory.") + + group.add_argument( + "--lang-type", + type=str, + default=None, + help=( + "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. " + "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor" + ), + ) + + @staticmethod + def Load(lang_dir: Path, lang_type="", oov=""): + + if not lang_type: + assert (lang_dir / "lang_type").exists(), "lang_type not specified." + lang_type = (lang_dir / "lang_type").read_text().strip() + + tokenizer = None + + if lang_type == "bpe": + assert ( + lang_dir / "bpe.model" + ).exists(), f"No BPE .model could be found in {lang_dir}." + tokenizer = spm.SentencePieceProcessor() + tokenizer.Load(str(lang_dir / "bpe.model")) + elif lang_type == "char": + tokenizer = CharTokenizer(lang_dir, oov=oov) + else: + raise NotImplementedError(f"{lang_type} not supported at the moment.") + + return tokenizer + + load = Load + + def PieceToId(self, piece: str) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + piece_to_id = PieceToId + + def IdToPiece(self, id: int) -> str: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + id_to_piece = IdToPiece + + def GetPieceSize(self) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + get_piece_size = GetPieceSize + + def __len__(self) -> int: + return self.get_piece_size() + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def EncodeAsIds(self, input: str) -> List[int]: + return self.EncodeAsIdsBatch([input])[0] + + def EncodeAsPieces(self, input: str) -> List[str]: + return self.EncodeAsPiecesBatch([input])[0] + + def Encode( + self, input: Union[str, List[str]], out_type=int + ) -> Union[List, List[List]]: + if not input: + return [] + + if isinstance(input, list): + if out_type is int: + return self.EncodeAsIdsBatch(input) + if out_type is str: + return self.EncodeAsPiecesBatch(input) + + if out_type is int: + return self.EncodeAsIds(input) + if out_type is str: + return self.EncodeAsPieces(input) + + encode = Encode + + def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def DecodeIds(self, input: List[int]) -> str: + return self.DecodeIdsBatch([input])[0] + + def DecodePieces(self, input: List[str]) -> str: + return self.DecodePiecesBatch([input])[0] + + def Decode( + self, + input: Union[int, List[int], List[str], List[List[int]], List[List[str]]], + ) -> Union[List[str], str]: + + if not input: + return "" + + if isinstance(input, int): + return self.id_to_piece(input) + elif isinstance(input, str): + raise TypeError( + "Unlike spm.SentencePieceProcessor, cannot decode from type str." 
+ ) + + if isinstance(input[0], list): + if not input[0] or isinstance(input[0][0], int): + return self.DecodeIdsBatch(input) + + if isinstance(input[0][0], str): + return self.DecodePiecesBatch(input) + + if isinstance(input[0], int): + return self.DecodeIds(input) + if isinstance(input[0], str): + return self.DecodePieces(input) + + raise RuntimeError("Unknown input type") + + decode = Decode + + def SplitBatch(self, input: List[str]) -> List[List[str]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]: + if isinstance(input, list): + return self.SplitBatch(input) + elif isinstance(input, str): + return self.SplitBatch([input])[0] + raise RuntimeError("Unknown input type") + + split = Split + + +class CharTokenizer(Tokenizer): + def __init__(self, lang_dir: Path, oov="", sep=""): + assert ( + lang_dir / "tokens.txt" + ).exists(), f"tokens.txt could not be found in {lang_dir}." + token_table = SymbolTable.from_file(lang_dir / "tokens.txt") + assert ( + "#0" not in token_table + ), "This tokenizer does not support disambig symbols." + self._id2sym = token_table._id2sym + self._sym2id = token_table._sym2id + self.oov = oov + self.oov_id = self._sym2id[oov] + self.sep = sep + if self.sep: + self.text2word = lambda x: x.split(self.sep) + else: + self.text2word = lambda x: list(x.replace(" ", "")) + + def piece_to_id(self, piece: str) -> int: + try: + return self._sym2id[piece] + except KeyError: + return self.oov_id + + def id_to_piece(self, id: int) -> str: + return self._id2sym[id] + + def get_piece_size(self) -> int: + return len(self._sym2id) + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input] + + def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: + return [ + [i if i in self._sym2id else self.oov for i in self.text2word(text)] + for text in input + ] + + def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: + return [self.sep.join(self.id_to_piece(i) for i in text) for text in input] + + def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: + return [self.sep.join(text) for text in input] + + def SplitBatch(self, input: List[str]) -> List[List[str]]: + return [self.text2word(text) for text in input] + + +def test_CharTokenizer(): + test_single_string = "こんにちは" + test_multiple_string = [ + "今日はいい天気ですよね", + "諏訪湖は綺麗でしょう", + "这在词表外", + "分かち 書き に し た 文章 です", + "", + ] + test_empty_string = "" + sp = Tokenizer.load(Path("lang_char"), "char", oov="") + splitter = sp.split + print(sp.encode(test_single_string, out_type=str)) + print(sp.encode(test_single_string, out_type=int)) + print(sp.encode(test_multiple_string, out_type=str)) + print(sp.encode(test_multiple_string, out_type=int)) + print(sp.encode(test_empty_string, out_type=str)) + print(sp.encode(test_empty_string, out_type=int)) + print(sp.decode(sp.encode(test_single_string, out_type=str))) + print(sp.decode(sp.encode(test_single_string, out_type=int))) + print(sp.decode(sp.encode(test_multiple_string, out_type=str))) + print(sp.decode(sp.encode(test_multiple_string, out_type=int))) + print(sp.decode(sp.encode(test_empty_string, out_type=str))) + print(sp.decode(sp.encode(test_empty_string, out_type=int))) + print(splitter(test_single_string)) + print(splitter(test_multiple_string)) + print(splitter(test_empty_string)) + + +if __name__ == "__main__": + test_CharTokenizer() 
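For reference, a minimal usage sketch of the Tokenizer facade above (a sketch under assumptions: the lang directory is data/lang_char as produced by local/prepare_lang_char.py, and the OOV symbol stored in tokens.txt is "<unk>"):

from pathlib import Path

from tokenizer import Tokenizer

# Load a character tokenizer; lang_type is read from data/lang_char/lang_type
# when it is not given explicitly.
sp = Tokenizer.load(Path("data/lang_char"), "char", oov="<unk>")

ids = sp.encode("こんにちは", out_type=int)     # one id per character
pieces = sp.encode("こんにちは", out_type=str)   # characters, OOVs mapped to the oov symbol
text = sp.decode(ids)                           # ids joined back into a string
chars = sp.split("分かち 書き に し た")         # text2word(): whitespace removed, split into chars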
diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 0c4c6c1ea..7f67c64b6 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -89,9 +89,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh index 052748ca6..52339bb35 100755 --- a/egs/csj/ASR/prepare.sh +++ b/egs/csj/ASR/prepare.sh @@ -32,19 +32,22 @@ # - speech # # By default, this script produces the original transcript like kaldi and espnet. Optionally, you -# can generate other transcript formats by supplying your own config files. A few examples of these +# can add other transcript formats by supplying your own config files. A few examples of these # config files can be found in local/conf. +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=8 stage=-1 stop_stage=100 -csj_dir=/mnt/minami_data_server/t2131178/corpus/CSJ -musan_dir=/mnt/minami_data_server/t2131178/corpus/musan/musan -trans_dir=$csj_dir/retranscript -csj_fbank_dir=/mnt/host/csj_data/fbank +csj_dir=/mnt/host/corpus/csj +musan_dir=/mnt/host/corpus/musan/musan +trans_dir=$csj_dir/transcript +csj_fbank_dir=/mnt/host/corpus/csj/fbank musan_fbank_dir=$musan_dir/fbank csj_manifest_dir=data/manifests musan_manifest_dir=$musan_dir/manifests @@ -60,12 +63,8 @@ log() { if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare CSJ manifest" - # If you want to generate more transcript modes, append the path to those config files at c. - # Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini - # NOTE: In case multiple config files are supplied, the second config file and onwards will inherit - # the segment boundaries of the first config file. if [ ! -e $csj_manifest_dir/.csj.done ]; then - lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4 + lhotse prepare csj $csj_dir $csj_manifest_dir -t $trans_dir -j 16 touch $csj_manifest_dir/.csj.done fi fi @@ -85,32 +84,24 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \ --fbank-dir $csj_fbank_dir parts=( - train - valid eval1 eval2 eval3 + valid + excluded + train ) for part in ${parts[@]}; do - python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz + python local/validate_manifest.py --manifest $csj_fbank_dir/csj_cuts_$part.jsonl.gz done touch $csj_fbank_dir/.csj-validated.done fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare CSJ lang" - modes=disfluent - - # If you want prepare the lang directory for other transcript modes, just append - # the names of those modes behind. 
An example is shown as below:- - # modes="$modes fluent symbol number" - - for mode in ${modes[@]}; do - python local/prepare_lang_char.py --trans-mode $mode \ - --train-cut $csj_manifest_dir/csj_cuts_train.jsonl.gz \ - --lang-dir lang_char_$mode - done + log "Stage 4: Prepare CSJ lang_char" + python local/prepare_lang_char.py $csj_fbank_dir/csj_cuts_train.jsonl.gz + python local/add_transcript_mode.py -f $csj_fbank_dir -c local/conf/fluent.ini local/conf/number.ini fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then @@ -125,6 +116,6 @@ fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Show manifest statistics" - python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt - cat $csj_manifest_dir/manifest_statistics.txt + python local/display_manifest_statistics.py --manifest-dir $csj_fbank_dir > $csj_fbank_dir/manifest_statistics.txt + cat $csj_fbank_dir/manifest_statistics.txt fi diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py new file mode 100644 index 000000000..f5235207a --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py @@ -0,0 +1,76 @@ +import logging +from configparser import ConfigParser + +import requests + + +def escape_html(text: str): + """ + Escapes all html characters in text + :param str text: + :rtype: str + """ + return text.replace("&", "&").replace("<", "<").replace(">", ">") + + +class TelegramStreamIO(logging.Handler): + + API_ENDPOINT = "https://api.telegram.org" + MAX_MESSAGE_LEN = 4096 + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s at %(funcName)s " + "(line %(lineno)s):\n\n%(message)s" + ) + + def __init__(self, tg_configfile: str): + super(TelegramStreamIO, self).__init__() + config = ConfigParser() + if not config.read(tg_configfile): + raise FileNotFoundError( + f"{tg_configfile} not found. " "Retry without --telegram-cred flag." + ) + config = config["TELEGRAM"] + token = config["token"] + self.chat_id = config["chat_id"] + self.url = f"{self.API_ENDPOINT}/bot{token}/sendMessage" + + @staticmethod + def setup_logger(params): + if not params.telegram_cred: + return + formatter = logging.Formatter( + f"{params.exp_dir.name} %(asctime)s \n%(message)s" + ) + tg = TelegramStreamIO(params.telegram_cred) + tg.setLevel(logging.WARN) + tg.setFormatter(formatter) + logging.getLogger("").addHandler(tg) + + def emit(self, record: logging.LogRecord): + """ + Emit a record. + Send the record to the Web server as a percent-encoded dictionary + """ + data = { + "chat_id": self.chat_id, + "text": self.format(self.mapLogRecord(record)), + "parse_mode": "HTML", + } + try: + requests.get(self.url, json=data) + # return response.json() + except Exception as e: + logging.error(f"Failed to send telegram message: {repr(e)}") + pass + + def mapLogRecord(self, record): + """ + Default implementation of mapping the log record into a dict + that is sent as the CGI data. Overwrite in your class. + Contributed by Franz Glasner. 
+ """ + + for k, v in record.__dict__.items(): + if isinstance(v, str): + setattr(record, k, escape_html(v)) + return record diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a48591198 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..d7349b0a3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..f5a1d750d --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,846 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --lang data/lang_char \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --lang data/lang_char \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --lang data/lang_char \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --lang data/lang_char \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --lang data/lang_char \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --lang data/lang_char \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --lang data/lang_char \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. 
+ Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir. It should contain at least a word table.", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. 
+ Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=30, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
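 + Note: in this CSJ recipe, `sp` is the character-level Tokenizer from
 + local/utils/tokenizer.py (loaded via Tokenizer.load), not a sentencepiece
 + BPE model.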
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.text2word(sp.decode(hyp))) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + 
key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = sp.text2word(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + return test_set_wers + + +@torch.no_grad() +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + if not params.res_dir: + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # and are defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, 
--avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + decoding_graph = None + word_table = None + + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + csj_corpus = CSJAsrDataModule(args) + + for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: + results_dict = decode_dataset( + dl=csj_corpus.test_dataloaders(getattr(csj_corpus, f"{subdir}_cuts")()), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, + test_set_name=subdir, + results_dict=results_dict, + ) + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..1ce277aa6 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..cb673b3eb --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 100755 index 000000000..ebdb596a5 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/csj/ASR + +repo_url=https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp_fluent/pretrained.pt" + +cd exp_fluent +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --lang $repo/data/lang_char \ + --exp-dir $repo/exp_fluent/ \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --decode-chunk-len 32 \ + --num-left-chunks 4 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp_fluent + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-ja-fluent-2023-02-14 + +Please also have a look at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-ja-fluent-2023-02-14/blob/main/export-for-ncnn-fluent.sh + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. 
+ """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + assert params.blank_id == 0, params.blank_id + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + 
logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 100644 index 000000000..2d45ecca3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/csj/ASR + ./pruned_transducer_stateless7_streaming/decode.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --lang data/lang_char + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + # You will find the pre-trained model in icefall-asr-csj-pruned-transducer-stateless7-230208/exp_fluent +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + 
filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 100644 index 000000000..ab7c8748a --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 + +""" +Usage: +# use -O to skip assertions and avoid some of the TracerWarnings +python -O pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, + params: AttributeDict, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + decode_chunk_len = params.decode_chunk_len # before subsampling + pad_length = 7 + s = f"decode_chunk_len: {decode_chunk_len}" + logging.info(s) + assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( + encoder_model.decode_chunk_size, + decode_chunk_len, + ) + + T = decode_chunk_len + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder_model.get_init_state(device=x.device) + + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. 
+ + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + 
export_encoder_model_jit_trace(model.encoder, encoder_filename, params) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 100644 index 000000000..d84cf04a3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --lang data/lang_char \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ +""" + +import argparse +import logging +from typing import List, Optional + +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. 
", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = Tokenizer.load(args.lang, args.lang_type) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = args.decode_chunk_len + assert encoder.decode_chunk_size == chunk_length // 2, ( + encoder.decode_chunk_size, + chunk_length, + ) + + # we subsample features with ((x_len - 7) // 2 + 1) // 2 + pad_length = 7 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + + states = encoder.get_init_state(device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32) + encoder_out, out_lens, states = encoder( + x=frames, + x_lens=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..482ebcfef --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py 
b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..16c2bf28d --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..522bbaff9 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 100644 index 000000000..932026868 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. 
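
For every usage listed here, pretrained.pt is nothing exotic: export.py saves a dict with a single "model" key holding the state_dict, and pretrained.py loads it back with load_state_dict. A minimal round-trip with a stand-in module (no icefall needed; the Linear layer is just a placeholder for the transducer):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # placeholder model
torch.save({"model": model.state_dict()}, "pretrained.pt")    # what export.py writes

checkpoint = torch.load("pretrained.pt", map_location="cpu")  # what pretrained.py does
model.load_state_dict(checkpoint["model"], strict=False)
model.eval()
print(sum(p.numel() for p in model.parameters()), "parameters loaded")
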
+ +Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_streaming/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + 
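
A note on the feature batching above: pad_sequence with padding_value=math.log(1e-10) is used because the features are log-mel energies, so a very negative constant behaves like padded silence rather than real energy. A tiny self-contained check of that batching (the shapes are made up):

import math

import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.randn(13, 80), torch.randn(9, 80)]  # (num_frames, num_mel_bins)
feature_lengths = torch.tensor([f.size(0) for f in feats])

batch = pad_sequence(feats, batch_first=True, padding_value=math.log(1e-10))
print(batch.shape)      # torch.Size([2, 13, 80])
print(batch[1, 9:, 0])  # the padded frames are all log(1e-10), roughly -23.0
print(feature_lengths)  # tensor([13,  9])
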
for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..a7ef73bcb --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..566c317ff --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 120000 index 000000000..92c3904af --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..2adf271c1 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..9700dd89e --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,597 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
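
streaming_decode.py below drives greedy, modified beam, and fast beam search chunk by chunk over many utterances at once. As a reminder of the core transducer greedy rule it builds on (at most one symbol per frame, blank means "emit nothing"), here is a toy version; toy_decoder and toy_joiner are random stand-ins, so the output is meaningless, but the control flow mirrors the greedy_search used in this recipe.

import torch

BLANK_ID = 0
CONTEXT_SIZE = 2


def toy_decoder(context):
    # Stand-in for the stateless decoder: one vector per 2-token context.
    return torch.randn(8)


def toy_joiner(encoder_frame, decoder_out):
    # Stand-in for the joiner: logits over a 5-symbol vocabulary.
    return torch.randn(5)


def greedy_decode(encoder_out: torch.Tensor):
    hyp = [BLANK_ID] * CONTEXT_SIZE
    decoder_out = toy_decoder(hyp[-CONTEXT_SIZE:])
    for t in range(encoder_out.size(0)):
        y = int(toy_joiner(encoder_out[t], decoder_out).argmax())
        if y != BLANK_ID:
            # Emit the symbol and refresh the decoder output from the new context.
            hyp.append(y)
            decoder_out = toy_decoder(hyp[-CONTEXT_SIZE:])
    return hyp[CONTEXT_SIZE:]


if __name__ == "__main__":
    print(greedy_decode(torch.randn(20, 4)))
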
+ +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --lang data/lang_char \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decode import save_results +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
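
The comment above is worth spelling out, since the tail padding below depends on it: the feature frontend maps x_len feature frames to (x_len - 7) // 2 embedded frames, and the deepest Zipformer stack downsamples those by up to 8x, so each chunk is padded to at least 23 frames to guarantee 8 embedded frames. A quick check of the arithmetic (plain Python, no model needed):

def embedded_frames(x_len: int) -> int:
    # Frame count after the convolutional feature embedding.
    return (x_len - 7) // 2


for x_len in (16, 23, 32 + 7):
    print(x_len, "->", embedded_frames(x_len))
# 16 -> 4, 23 -> 8, 39 -> 16 (39 = decode_chunk_len 32 + pad_length 7)
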
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
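+        # Streams are accumulated until --num-decode-streams of them are active;
+        # the while-loop below then calls decode_one_chunk() to advance every
+        # active stream by one chunk and drains the finished ones, so memory use
+        # stays bounded no matter how many cuts the dataset contains.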
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +@torch.no_grad() +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if not params.res_dir: + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # and is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = 
sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + args.return_cuts = True + csj_corpus = CSJAsrDataModule(args) + + for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: + results_dict = decode_dataset( + cuts=getattr(csj_corpus, f"{subdir}_cuts")(), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, test_set_name=subdir, results_dict=results_dict + ) + + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}") for k, v 
in tot_err) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..0a82ccfa4 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/csj/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
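+    # torch.jit.ignore keeps forward() as a plain Python method, so the
+    # torch.jit.script() call below compiles the rest of the model (encoder,
+    # decoder, joiner) without tripping over the ragged-tensor argument
+    # mentioned above.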
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py new file mode 120000 index 000000000..958c99e85 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py @@ -0,0 +1 @@ +../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..601de2c41 --- /dev/null +++ 
b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1304 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + +try: + from TelegramStreamIO import TelegramStreamIO + + HAS_TELEGRAM = True +except ImportError: + HAS_TELEGRAM = False + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward 
dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") + + parser.add_argument( + "--telegram-cred", + type=Path, + default=None, + help="Telegram credentials to report progress in telegram", + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. 
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. 
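+            # warm_step controls the loss warm-up in compute_loss(): over the
+            # first warm_step batches the scale on the simple loss decays from
+            # 1.0 to --simple-loss-scale while the scale on the pruned loss
+            # grows from 0.1 to 1.0.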
+ "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
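+    # load_checkpoint() restores the model/optimizer/scheduler states in place
+    # and returns the remaining saved entries (best train/valid losses,
+    # batch_idx_train, sampler and grad-scaler states, ...) as a dict; the loop
+    # below copies the bookkeeping values back into `params`.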
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
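+            # Example: if an overflow drove the scale down to 0.5, it is doubled
+            # on every pass through this block (every 100 batches) until it is
+            # back above 1.0; a scale between 1.0 and 8.0 is only doubled every
+            # 400 batches, and training aborts if the scale falls below 1e-05.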
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: + logging.warning( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + else: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + if ( + HAS_TELEGRAM + and batch_idx % (params.valid_interval * 3) == 0 + and not rank + ): + log_mode = logging.warning + else: + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, master_port=params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + if HAS_TELEGRAM and params.telegram_cred: + TelegramStreamIO.setup_logger(params) + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 0.3 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.info( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + csj_corpus = CSJAsrDataModule(args) + train_cuts = csj_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = csj_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = csj_corpus.valid_cuts() + valid_dl = csj_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..d1913d718 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1305 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
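+# This file mirrors ./train.py; the visible differences are that it imports
+# Zipformer from zipformer2 and builds the encoder with is_pnnx=True, which
+# appears to be intended for PNNX/ncnn export of the streaming model.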
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + +try: + from TelegramStreamIO import TelegramStreamIO + + HAS_TELEGRAM = True +except ImportError: + HAS_TELEGRAM = False + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. 
" + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") + + parser.add_argument( + "--telegram-cred", + type=Path, + default=None, + help="Telegram credentials to report progress in telegram", + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] 
= None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. 
+ optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: + logging.warning( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + else: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + if ( + HAS_TELEGRAM + and batch_idx % (params.valid_interval * 3) == 0 + and not rank + ): + log_mode = logging.warning + else: + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, master_port=params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + if HAS_TELEGRAM and params.telegram_cred: + TelegramStreamIO.setup_logger(params) + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 0.3 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.info( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + csj_corpus = CSJAsrDataModule(args) + train_cuts = csj_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = csj_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = csj_corpus.valid_cuts() + valid_dl = csj_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 120000 index 000000000..12dbda888 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/.gitignore b/egs/gigaspeech/ASR/.gitignore index 5592679cc..8dec2d86d 100644 --- a/egs/gigaspeech/ASR/.gitignore +++ b/egs/gigaspeech/ASR/.gitignore @@ -1 +1,2 @@ log-* +.DS_Store \ No newline at end of file diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index d78e26240..9437c935c 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -183,23 +183,18 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev (speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -221,9 +216,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: 
{self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -256,9 +249,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -304,9 +295,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -362,9 +351,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" - ) + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 6fac07f93..a1cfe6e75 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if 
self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,31 +696,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -771,9 +750,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +762,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +796,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index 51406667e..d7035a1f8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -476,14 +476,12 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty 
in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] - for ref_text in texts: + for cut_id, ref_text in zip(cut_ids, texts): ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words)) for lm_scale in results.keys(): results[lm_scale].extend(this_batch) @@ -493,9 +491,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -528,9 +524,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -705,9 +699,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..3b94f0c4b 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,13 +78,10 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 2965cde18..4883d04d8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,9 +386,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -521,9 +519,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -641,9 +637,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 8209ee3ec..07beeb1f0 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,9 +77,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 6410249db..1c71be0f9 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -134,9 +134,7 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/tedlium3/ASR/local/generate_unique_lexicon.py b/egs/gigaspeech/ASR/local/generate_unique_lexicon.py similarity index 100% rename from egs/tedlium3/ASR/local/generate_unique_lexicon.py rename to egs/gigaspeech/ASR/local/generate_unique_lexicon.py diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 48d10a157..31abe7fff 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,19 +98,13 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index fd2532741..bd255dc6a 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index c87686e1e..4d5d2b8f9 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -195,8 +195,7 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev " - "(speeds up training)", + help="Should we use only 1000 utterances for dev (speeds up training)", ) def train_dataloaders( @@ -216,13 +215,9 @@ 
class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -244,9 +239,7 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -289,9 +282,7 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -347,9 +338,7 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -398,7 +387,7 @@ class GigaSpeechAsrDataModule: @lru_cache() def train_cuts(self) -> CutSet: logging.info(f"About to get train_{self.args.subset} cuts") - path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz" + path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" cuts_train = CutSet.from_jsonl_lazy(path) return cuts_train @@ -406,7 +395,7 @@ class GigaSpeechAsrDataModule: def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "cuts_DEV.jsonl.gz" + self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" ) if self.args.small_dev: return cuts_valid.subset(first=1000) @@ -416,4 +405,6 @@ class GigaSpeechAsrDataModule: @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + ) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py new file mode 100755 index 000000000..76306fc4c --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corp. (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./pruned_transducer_stateless7/compute_ppl.py \ + --ngram-lm-path ./download/lm/3gram_pruned_1e7.arpa + +""" + + +import argparse +import logging +import math +from typing import Dict, List, Optional, Tuple + +import kenlm +import torch +from asr_datamodule import GigaSpeechAsrDataModule + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--ngram-lm-path", + type=str, + default="download/lm/3gram_pruned_1e7.arpa", + help="The lang dir containing word table and LG graph", + ) + + return parser + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: kenlm.Model, +) -> Dict[str, float]: + """ + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + A ngram lm of kenlm.Model object. + Returns: + Return the perplexity of the giving dataset. + """ + sum_score_log = 0 + sum_n = 0 + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + for text in texts: + sum_n += len(text.split()) + 1 + sum_score_log += -1 * model.score(text) + + ppl = math.pow(10.0, sum_score_log / sum_n) + + return ppl + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + logging.info("About to load ngram LM") + model = kenlm.Model(args.ngram_lm_path) + + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + ppl = decode_dataset( + dl=test_dl, + model=model, + ) + logging.info(f"{test_set} PPL: {ppl}") + + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 5849a3471..72f74c968 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -19,40 +19,40 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method greedy_search + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method greedy_search (2) beam search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 (3) modified beam search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 (4) fast beam search 
./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -76,9 +76,9 @@ from beam_search import ( ) from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model - from icefall.checkpoint import ( average_checkpoints, + average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) @@ -86,6 +86,7 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -98,9 +99,9 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=29, + default=30, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -123,6 +124,17 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + parser.add_argument( "--exp-dir", type=str, @@ -188,8 +200,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -258,9 +269,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -275,10 +284,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -398,9 +404,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -411,9 +415,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = post_processing(results) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -421,9 +423,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -433,10 +433,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -483,6 +480,9 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -495,7 +495,7 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and is defined in local/train_bpe_model.py + # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -505,38 +505,85 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints 
found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() - model.device = device if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index cff9c7377..b6190e8a6 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) return parser @@ -160,8 +155,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +203,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 83ae25561..578bd9218 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -19,31 +19,30 @@ """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2,3" +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" ./pruned_transducer_stateless2/train.py \ - --world-size 4 \ + --world-size 8 \ --num-epochs 30 \ --start-epoch 0 \ --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 300 + --max-duration 120 # For mix precision training: ./pruned_transducer_stateless2/train.py \ - --world-size 4 \ + --world-size 8 \ --num-epochs 30 \ --start-epoch 0 \ --use_fp16 1 \ --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 550 + --max-duration 200 """ import argparse +import copy import logging import warnings from pathlib import Path @@ -72,14 +71,15 @@ from torch.utils.tensorboard import SummaryWriter from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -118,10 +118,10 @@ def get_parser(): parser.add_argument( "--start-epoch", type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless2/exp/epoch-{start_epoch-1}.pt + default=1, + help="""Resume training from this epoch. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt """, ) @@ -178,8 +178,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -202,8 +201,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -246,7 +244,7 @@ def get_parser(): parser.add_argument( "--keep-last-k", type=int, - default=20, + default=30, help="""Only keep this number of checkpoints on disk. 
For instance, if it is 3, there are only 3 checkpoints in the exp-dir with filenames `checkpoint-xxx.pt`. @@ -254,6 +252,19 @@ def get_parser(): """, ) + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + parser.add_argument( "--use-fp16", type=str2bool, @@ -391,6 +402,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, + model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: @@ -398,7 +410,7 @@ def load_checkpoint_if_available( If params.start_batch is positive, it will load the checkpoint from `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from + params.start_epoch is larger than 1, it will load the checkpoint from `params.start_epoch - 1`. Apart from loading state dict for `model` and `optimizer` it also updates @@ -410,6 +422,8 @@ def load_checkpoint_if_available( The return value of :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer that we are using. scheduler: @@ -419,7 +433,7 @@ def load_checkpoint_if_available( """ if params.start_batch > 0: filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: + elif params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" else: return None @@ -429,6 +443,7 @@ def load_checkpoint_if_available( saved_params = load_checkpoint( filename, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, ) @@ -455,7 +470,8 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, @@ -469,6 +485,8 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer used in the training. sampler: @@ -482,6 +500,7 @@ def save_checkpoint( save_checkpoint_impl( filename=filename, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -501,14 +520,14 @@ def save_checkpoint( def compute_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, warmup: float = 1.0, ) -> Tuple[Tensor, MetricsTracker]: """ - Compute CTC loss given the model and its inputs. + Compute transducer loss given the model and its inputs. Args: params: @@ -553,23 +572,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
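# Illustrative sketch, not icefall's `update_averaged_model` (imported above):
# the update described by the new `--average-period` option,
#     model_avg = model * (average_period / batch_idx_train)
#                 + model_avg * ((batch_idx_train - average_period) / batch_idx_train),
# is an incremental mean over training batches. In stand-alone form, over
# plain parameter dicts (names below are only for illustration):
import torch


def update_running_average(
    params: dict, avg_params: dict, batch_idx_train: int, average_period: int
) -> None:
    weight = average_period / batch_idx_train
    with torch.no_grad():
        for name, p in params.items():
            avg = avg_params[name]
            # avg <- avg * (1 - w) + p * w, with w = average_period / batch_idx_train
            avg.mul_(1.0 - weight).add_(p.to(avg.dtype), alpha=weight)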
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -581,7 +593,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -615,13 +627,14 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -647,6 +660,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -673,6 +688,7 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, + model_avg=model_avg, sp=sp, batch=batch, is_training=True, @@ -701,6 +717,7 @@ def train_one_epoch( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -732,9 +749,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -806,7 +821,16 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoints = load_checkpoint_if_available(params=params, model=model) + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) model.to(device) if world_size > 1: @@ -865,10 +889,10 @@ def run(rank, world_size, args): logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + 
train_dl.sampler.set_epoch(epoch - 1) if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) @@ -878,6 +902,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sp=sp, @@ -896,6 +921,7 @@ def run(rank, world_size, args): save_checkpoint( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, @@ -911,7 +937,7 @@ def run(rank, world_size, args): def scan_pessimistic_batches_for_oom( - model: nn.Module, + model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, diff --git a/egs/librispeech/ASR/.gitignore b/egs/librispeech/ASR/.gitignore index 5592679cc..1c26f7978 100644 --- a/egs/librispeech/ASR/.gitignore +++ b/egs/librispeech/ASR/.gitignore @@ -1 +1,3 @@ log-* +.DS_Store +run*.sh diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index c366650bb..82cef9817 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -1,6 +1,6 @@ # Introduction -Please refer to for how to run models in this recipe. +Please refer to for how to run models in this recipe. [./RESULTS.md](./RESULTS.md) contains the latest results. @@ -19,16 +19,37 @@ The following table lists the differences among them. | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | | `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data | -| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training | +| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training + delay penalty | | `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert| | `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan| +| `pruned_transducer_stateless7_ctc` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but with extra CTC head| +| `pruned_transducer_stateless7_ctc_bs` | Zipformer | Embedding + Conv1d | pruned_transducer_stateless7_ctc + blank skip | +| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 | +| `pruned_transducer_stateless7_streaming_multi` | Streaming Zipformer | Embedding + Conv1d | same as pruned_transducer_stateless7_streaming, trained on LibriSpeech + GigaSpeech | +| `pruned_transducer_stateless8` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model | | 
`conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model | | `lstm_transducer_stateless` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model | -| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) | +| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) | +| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty | The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer. + +# CTC + +| | Encoder | Comment | +|------------------------------|--------------------|------------------------------| +| `conformer-ctc` | Conformer | Use auxiliary attention head | +| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | +| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | + +# MMI + +| | Encoder | Comment | +|------------------------------|-----------|---------------------------------------------------| +| `conformer-mmi` | Conformer | | +| `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 43cd67c85..2ca0558ab 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,776 @@ ## Results +### pruned_transducer_stateless7 (zipformer + multidataset(LibriSpeech + GigaSpeech + CommonVoice 13.0)) + +See for more details. + +[pruned_transducer_stateless7](./pruned_transducer_stateless7) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +You can use to deploy it. 
+ +Number of model parameters: 70369391, i.e., 70.37 M + +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 1.91 | 4.06 | --epoch 30 --avg 7 | +| modified_beam_search | 1.90 | 3.99 | --epoch 30 --avg 7 | +| fast_beam_search | 1.90 | 3.98 | --epoch 30 --avg 7 | + + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./pruned_transducer_stateless7/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --use-multidataset 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7/exp +``` + +The decoding commands are: +```bash +# greedy_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +# modified_beam_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +# fast_beam_search +./pruned_transducer_stateless7/decode.py \ + --epoch 30 \ + --avg 7 \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +``` + +### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer + Multi-Dataset) + +#### [pruned_transducer_stateless7_streaming_multi](./pruned_transducer_stateless7_streaming_multi) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + +Number of model parameters: 70369391, i.e., 70.37 M + +##### training on full librispeech + full gigaspeech (with giga_prob=0.9) + +The WERs are: + + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 2.43 | 6.0 | --epoch 20 --avg 4 | simulated streaming | +| greedy search | 320ms | 2.47 | 6.13 | --epoch 20 --avg 4 | chunk-wise | +| fast beam search | 320ms | 2.43 | 5.99 | --epoch 20 --avg 4 | simulated streaming | +| fast beam search | 320ms | 2.8 | 6.46 | --epoch 20 --avg 4 | chunk-wise | +| modified beam search | 320ms | 2.4 | 5.96 | --epoch 20 --avg 4 | simulated streaming | +| modified beam search | 320ms | 2.42 | 6.03 | --epoch 20 --avg 4 | chunk-size | +| greedy search | 640ms | 2.26 | 5.58 | --epoch 20 --avg 4 | simulated streaming | +| greedy search | 640ms | 2.33 | 5.76 | --epoch 20 --avg 4 | chunk-wise | +| fast beam search | 640ms | 2.27 | 5.54 | --epoch 20 --avg 4 | simulated streaming | +| fast beam search | 640ms | 2.37 | 5.75 | --epoch 20 --avg 4 | chunk-wise | +| modified beam search | 640ms | 2.22 | 5.5 | --epoch 20 --avg 4 | simulated streaming | +| modified beam search | 640ms | 2.25 | 5.69 | --epoch 20 --avg 4 | chunk-size | + +The model also has good WERs on GigaSpeech. 
The following WERs are achieved on GigaSpeech test and dev sets: + +| decoding method | chunk size | dev | test | comment | decoding mode | +|----------------------|------------|-----|------|------------|---------------------| +| greedy search | 320ms | 12.08 | 11.98 | --epoch 20 --avg 4 | simulated streaming | +| greedy search | 640ms | 11.66 | 11.71 | --epoch 20 --avg 4 | simulated streaming | +| modified beam search | 320ms | 11.95 | 11.83 | --epoch 20 --avg 4 | simulated streaming | +| modified beam search | 320ms | 11.65 | 11.56 | --epoch 20 --avg 4 | simulated streaming | + + +Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, +while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. + +The training command is: + +```bash +./pruned_transducer_stateless7_streaming_multi/train.py \ + --world-size 4 \ + --num-epochs 20 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming_multi/exp \ + --full-libri 1 \ + --giga-prob 0.9 \ + --max-duration 750 \ + --master-port 12345 +``` + +The tensorboard log can be found at + + +The simulated streaming decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search fast_beam_search modified_beam_search; do + ./pruned_transducer_stateless7_streaming_multi/decode.py \ + --epoch 20 \ + --avg 4 \ + --exp-dir ./pruned_transducer_stateless7_streaming_multi/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --right-padding 64 \ + --decoding-method $m +done +``` + +The streaming chunk-size decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless7_streaming_multi/streaming_decode.py \ + --epoch 20 \ + --avg 4 \ + --exp-dir ./pruned_transducer_stateless7_streaming_multi/exp \ + --decoding-method $m \ + --decode-chunk-len 32 \ + --num-decode-streams 2000 +done +``` + + +#### Smaller model + +We also provide a very small version (only 6.1M parameters) of this setup. 
The training command for the small model is: + +```bash +./pruned_transducer_stateless7_streaming_multi/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming_multi/exp \ + --full-libri 1 \ + --giga-prob 0.9 \ + --num-encoder-layers "2,2,2,2,2" \ + --feedforward-dims "256,256,512,512,256" \ + --nhead "4,4,4,4,4" \ + --encoder-dims "128,128,128,128,128" \ + --attention-dims "96,96,96,96,96" \ + --encoder-unmasked-dims "96,96,96,96,96" \ + --max-duration 1200 \ + --master-port 12345 +``` + +You can find this pretrained small model and its training logs, decoding logs, and decoding +results at: + + + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 5.95 | 15.03 | --epoch 30 --avg 1 | simulated streaming | +| greedy search | 640ms | 5.61 | 13.86 | --epoch 30 --avg 1 | simulated streaming | +| modified beam search | 320ms | 5.72 | 14.34 | --epoch 30 --avg 1 | simulated streaming | +| modified beam search | 640ms | 5.43 | 13.16 | --epoch 30 --avg 1 | simulated streaming | +| fast beam search | 320ms | 5.88 | 14.45 | --epoch 30 --avg 1 | simulated streaming | +| fast beam search | 640ms | 5.48 | 13.31 | --epoch 30 --avg 1 | simulated streaming | + +This small model achieves the following WERs on GigaSpeech test and dev sets: + +| decoding method | chunk size | dev | test | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 17.57 | 17.2 | --epoch 30 --avg 1 | simulated streaming | +| modified beam search | 320ms | 16.98 | 11.98 | --epoch 30 --avg 1 | simulated streaming | + +You can find the tensorboard logs at . + +### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) + +#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +See for more details. 
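The `simulated streaming` and `chunk-wise` numbers reported in these streaming sections differ only in how the features reach the encoder: `decode.py` feeds the whole utterance at once, while `streaming_decode.py` feeds a fixed number of frames per step and carries encoder states across steps. Below is a rough sketch of that chunk-wise loop, assuming a hypothetical stateful encoder interface (not the recipe's actual API); with the usual 10 ms frame shift, `chunk_len=32` matches the 320 ms chunk size quoted in the tables.

```python
import torch


def chunk_wise_encode(streaming_encoder, feats: torch.Tensor, chunk_len: int = 32):
    """Feed `feats` of shape (num_frames, feat_dim) in fixed-size chunks.

    `streaming_encoder(chunk, states)` is assumed to return
    (encoder_out, new_states) and to keep no other global state.
    """
    states = None
    outputs = []
    for start in range(0, feats.size(0), chunk_len):
        chunk = feats[start : start + chunk_len].unsqueeze(0)  # (1, chunk, feat_dim)
        out, states = streaming_encoder(chunk, states)
        outputs.append(out)
    return torch.cat(outputs, dim=1)
```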
+ +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 70369391, i.e., 70.37 M + +##### training on full librispeech + +The WERs are: + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 3.15 | 8.09 | --epoch 30 --avg 9 | simulated streaming | +| greedy search | 320ms | 3.17 | 8.24 | --epoch 30 --avg 9 | chunk-wise | +| fast beam search | 320ms | 3.2 | 8.04 | --epoch 30 --avg 9 | simulated streaming | +| fast beam search | 320ms | 3.36 | 8.19 | --epoch 30 --avg 9 | chunk-wise | +| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 320ms | 3.12 | 8.11 | --epoch 30 --avg 9 | chunk-size | +| greedy search | 640ms | 2.97 | 7.5 | --epoch 30 --avg 9 | simulated streaming | +| greedy search | 640ms | 2.98 | 7.67 | --epoch 30 --avg 9 | chunk-wise | +| fast beam search | 640ms | 3.02 | 7.47 | --epoch 30 --avg 9 | simulated streaming | +| fast beam search | 640ms | 2.96 | 7.61 | --epoch 30 --avg 9 | chunk-wise | +| modified beam search | 640ms | 2.94 | 7.36 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 640ms | 2.95 | 7.53 | --epoch 30 --avg 9 | chunk-size | + +Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, +while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`. + +The training command is: + +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 750 \ + --master-port 12345 +``` + +The tensorboard log can be found at + + +The simulated streaming decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search fast_beam_search modified_beam_search; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method $m +done +``` + +The streaming chunk-size decoding command (e.g., chunk-size=320ms) is: +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 30 \ + --avg 9 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding-method $m \ + --decode-chunk-len 32 \ + --num-decode-streams 2000 +done +``` +We also support decoding with neural network LMs. 
After combining with language models, the WERs are +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| `modified_beam_search` | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_shallow_fusion` | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_rescore` | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming | +| `modified_beam_search_lm_rescore_LODR` | 320ms | 2.52 | 6.73 | --epoch 30 --avg 9 | simulated streaming | + +Please use the following command for `modified_beam_search_lm_shallow_fusion`: +```bash +for lm_scale in $(seq 0.15 0.01 0.38); do + for beam_size in 4 8 12; do + ./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model False \ + --beam-size $beam_size \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp-large-LM \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir rnn_lm/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 + done +done +``` + +Please use the following command for `modified_beam_search_lm_rescore`: +```bash +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 30 \ + --avg 9 \ + --use-averaged-model True \ + --beam-size 8 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_rescore \ + --use-shallow-fusion 0 \ + --lm-type rnn \ + --lm-exp-dir rnn_lm/exp \ + --lm-epoch 99 \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 +``` + +Please use the following command for `modified_beam_search_lm_rescore_LODR`: +```bash +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 30 \ + --avg 9 \ + --use-averaged-model True \ + --beam-size 8 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search_lm_rescore_LODR \ + --use-shallow-fusion 0 \ + --lm-type rnn \ + --lm-exp-dir rnn_lm/exp \ + --lm-epoch 99 \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 3 \ + --lm-vocab-size 500 \ + --tokens-ngram 2 \ + --backoff-id 500 +``` + +A well-trained RNNLM can be found here: . The bi-gram used in LODR decoding +can be found here: . + + +#### Smaller model + +A smaller model (~20M params) is also available with configuration based on [this comment](https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740). The WERs are: + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 3.94 | 9.79 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 320ms | 3.88 | 9.53 | --epoch 30 --avg 9 | simulated streaming | + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +### zipformer_mmi (zipformer with mmi loss) + +See for more details. 
+ +[zipformer_mmi](./zipformer_mmi) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 69136519, i.e., 69.14 M + +| | test-clean | test-other | comment | +| ---------------------- | ---------- | ---------- | ------------------- | +| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 | +| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 | +| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 | +| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 | +| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 | + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer_mmi/train.py \ + --world-size 4 \ + --master-port 12345 \ + --num-epochs 30 \ + --start-epoch 1 \ + --lang-dir data/lang_bpe_500 \ + --max-duration 500 \ + --full-libri 1 \ + --use-fp16 1 \ + --exp-dir zipformer_mmi/exp +``` + +The decoding commands for the transducer branch are: +```bash +export CUDA_VISIBLE_DEVICES="5" + +for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do + ./zipformer_mmi/decode.py \ + --epoch 30 \ + --avg 10 \ + --exp-dir ./zipformer_mmi/exp/ \ + --max-duration 100 \ + --lang-dir data/lang_bpe_500 \ + --nbest-scale 1.2 \ + --hp-scale 1.0 \ + --decoding-method $m +done +``` + +### pruned_transducer_stateless7_ctc_bs (zipformer with transducer loss and ctc loss using blank skip) + +See https://github.com/k2-fsa/icefall/pull/730 for more details. + +[pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 76804822, i.e., 76.80 M + +Test on 8-card V100 cluster, with 4-card busy and 4-card idle. + +#### greedy_search + +| model | test-clean | test-other | decoding time(s) | comment | +| ------------------------------------------------------------ | ---------- | ---------- | ---------------- | ------------------- | +| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.28 | 5.53 | 48.939 | --epoch 30 --avg 13 | +| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.24 | 5.18 | 91.900 | --epoch 30 --avg 8 | + +- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip both on training and decoding, and [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) doesn`t apply blank skip. +- Applying blank skip both on training and decoding is **1.88 times** faster than the model that doesn't apply blank skip without obvious performance loss. + +#### modified_beam_search + +| model | test-clean | test-other | decoding time(s) | comment | +| ------------------------------------------------------------ | ---------- | ---------- | ---------------- | ------------------- | +| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.26 | 5.44 | 80.446 | --epoch 30 --avg 13 | +| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.20 | 5.12 | 283.676 | --epoch 30 --avg 8 | + +- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip both on training and decoding, and [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) doesn`t apply blank skip. 
+- Applying blank skip both on training and decoding is **3.53 times** faster than the model that doesn't apply blank skip without obvious performance loss. + +The training commands for the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --ctc-loss-scale 0.2 \ + --master-port 12535 +``` + +The decoding commands for the transducer branch of the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are: + +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + for epoch in 30; do + for avg in 15; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + +The decoding commands for the transducer branch of the model without blank skip ([pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc)) are: + +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + for epoch in 30; do + for avg in 15; do + ./pruned_transducer_stateless7_ctc/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + +### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss) + +See for more details. 
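The speed-ups quoted for the blank-skip variant above come from using the CTC head (which this recipe also has) to discard frames that are almost certainly blank before the transducer decoder runs over them. A minimal sketch of that frame selection, assuming a simple per-frame blank-probability threshold (the recipe's actual criterion may differ):

```python
import torch


def select_non_blank_frames(
    encoder_out: torch.Tensor,  # (num_frames, encoder_dim)
    ctc_log_probs: torch.Tensor,  # (num_frames, vocab_size), blank at index 0
    blank_threshold: float = 0.95,
) -> torch.Tensor:
    """Keep only the frames whose CTC blank probability is below the threshold."""
    blank_prob = ctc_log_probs[:, 0].exp()
    keep = blank_prob < blank_threshold
    return encoder_out[keep]
```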
+ +[pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 70561891, i.e., 70.56 M + +| | test-clean | test-other | comment | +|--------------------------|------------|-------------|--------------------| +| greedy search | 2.23 | 5.19 | --epoch 30 --avg 8 | +| modified beam search | 2.21 | 5.12 | --epoch 30 --avg 8 | +| fast beam search | 2.23 | 5.18 | --epoch 30 --avg 8 | +| ctc decoding | 2.48 | 5.82 | --epoch 30 --avg 9 | +| 1best | 2.43 | 5.22 | --epoch 30 --avg 9 | +| nbest | 2.43 | 5.22 | --epoch 30 --avg 9 | +| nbest rescoring | 2.34 | 5.05 | --epoch 30 --avg 9 | +| whole lattice rescoring | 2.34 | 5.04 | --epoch 30 --avg 9 | + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7_ctc/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --ctc-loss-scale 0.2 \ + --master-port 12535 +``` + +The decoding commands for the transducer branch are: +```bash +for m in greedy_search fast_beam_search modified_beam_search ; do + for epoch in 30; do + for avg in 8; do + ./pruned_transducer_stateless7_ctc/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + +The decoding commands for the ctc branch are: +```bash +for m in ctc-decoding nbest nbest-rescoring whole-lattice-rescoring; do + for epoch in 30; do + for avg in 9; do + ./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 100 \ + --decoding-method $m \ + --hlg-scale 0.6 \ + --lm-dir data/lm + done + done +done +``` + + +### LibriSpeech BPE training results (Conformer CTC, supporting delay penalty) + +#### [conformer_ctc3](./conformer_ctc3) + +It implements Conformer model training with CTC loss. +For streaming mode, it supports symbol delay penalty. + +See for more details. + +##### training on full librispeech + +This model contains 12 encoder layers. The number of model parameters is 77352694. 
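Among the decoding methods used by the CTC recipes in this file, `ctc-decoding` and `ctc-greedy-search` operate directly on the frame-level CTC posteriors (the recipe's `ctc-decoding` additionally goes through k2 graphs). The basic greedy rule, collapse repeats and then drop blanks, is shown below as a minimal sketch:

```python
import torch


def ctc_greedy_search(log_probs: torch.Tensor, blank_id: int = 0) -> list:
    """Greedy CTC decoding for a single utterance.

    log_probs: (num_frames, vocab_size) log-probabilities from the CTC head.
    Returns token ids with repeats collapsed and blanks removed.
    """
    best = log_probs.argmax(dim=-1).tolist()
    hyp, prev = [], None
    for t in best:
        if t != prev and t != blank_id:
            hyp.append(t)
        prev = t
    return hyp
```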
+ +The WERs are: + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|----------------------| +| ctc-decoding | 3.09 | 7.62 | --epoch 25 --avg 7 | +| 1best | 2.87 | 6.44 | --epoch 25 --avg 7 | +| nbest | 2.88 | 6.5 | --epoch 25 --avg 7 | +| nbest-rescoring | 2.71 | 6.1 | --epoch 25 --avg 7 | +| whole-lattice-rescoring | 2.71 | 6.04 | --epoch 25 --avg 7 | + +The training command is: + +```bash +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 25 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/full \ + --full-libri 1 \ + --max-duration 300 \ + --master-port 12345 +``` + +The tensorboard log can be found at + + +The decoding command using different methods is: +```bash +for method in ctc-decoding 1best nbest nbest-rescoring whole-lattice-rescoring; do + ./conformer_ctc3/decode.py \ + --epoch 25 \ + --avg 7 \ + --exp-dir conformer_ctc3/exp \ + --max-duration 300 \ + --decoding-method $method \ + --manifest-dir data/fbank \ + --lm-dir data/lm \ +done +``` + +Pretrained models, training logs, decoding logs, and decoding results +are available at + + +The command to train a streaming model with symbol delay penalty is: +```bash +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 \ + --delay-penalty 0.1 +``` +To evaluate symbol delay, you should: +(1) Generate cuts with word-time alignments: +```bash +./local/add_alignment_librispeech.py \ + --alignments-dir data/alignment \ + --cuts-in-dir data/fbank \ + --cuts-out-dir data/fbank_ali +``` +(2) Set the argument "--manifest-dir data/fbank_ali" while decoding. +For example: +```bash +./conformer_ctc3/decode.py \ + --epoch 25 \ + --avg 7 \ + --exp-dir ./conformer_ctc3/exp \ + --max-duration 300 \ + --decoding-method ctc-decoding \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --manifest-dir data/fbank_ali +``` +Note: It supports to calculate symbol delay with following decoding methods: + - ctc-greedy-search + - ctc-decoding + - 1best + + +### pruned_transducer_stateless8 (zipformer + multidataset) + +See for more details. + +[pruned_transducer_stateless8](./pruned_transducer_stateless8) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +You can use to deploy it. 
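Coming back to the symbol-delay evaluation described for `conformer_ctc3` above: once the cuts carry word-level alignments, the delay can be summarized as the average gap between the time at which each word is emitted and its aligned reference start time. A rough sketch of such a metric, assuming hypothesis and reference words are already paired up (the recipe's exact definition may differ):

```python
from typing import List


def mean_symbol_delay(hyp_times: List[float], ref_times: List[float]) -> float:
    """Average emission delay in seconds over paired hypothesis/reference words."""
    assert len(hyp_times) == len(ref_times)
    return sum(h - r for h, r in zip(hyp_times, ref_times)) / max(len(ref_times), 1)
```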
+ +Number of model parameters: 70369391, i.e., 70.37 M + +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 1.81 | 4.18 | --epoch 20 --avg 4 | +| fast_beam_search | 1.82 | 4.15 | --epoch 20 --avg 4 | +| modified_beam_search | 1.78 | **4.08** | --epoch 20 --avg 4 | +| greedy_search | 1.84 | 4.3 | --epoch 19 --avg 8 | +| fast_beam_search |**1.77** | 4.25 | --epoch 19 --avg 8 | +| modified_beam_search | 1.81 | 4.16 | --epoch 19 --avg 8 | + + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./pruned_transducer_stateless8/train.py \ + --world-size 8 \ + --num-epochs 20 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless8/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --master-port 12535 \ + --giga-prob 0.9 +``` + +The decoding commands are: +```bash +for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in $(seq 20 -1 10); do + for avg in $(seq 9 -1 1); do + ./pruned_transducer_stateless8/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + + ### pruned_transducer_stateless7 (zipformer) See for more details. @@ -19,9 +790,13 @@ Number of model parameters: 70369391, i.e., 70.37 M | | test-clean | test-other | comment | |----------------------|------------|-------------|----------------------------------------| -| greedy search | 2.17 | 5.23 | --epoch 39 --avg 6 --max-duration 600 | -| modified beam search | 2.15 | 5.20 | --epoch 39 --avg 6 --max-duration 600 | -| fast beam search | 2.15 | 5.22 | --epoch 39 --avg 6 --max-duration 600 | +| greedy search | 2.17 | 5.23 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search | 2.15 | 5.20 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 1.99 | 4.73 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + TransformerLM shallow fusion | 1.94 | 4.73 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + RNNLM + LODR | 1.91 | 4.57 | --epoch 30 --avg 9 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 1.91 | 4.51 | --epoch 30 --avg 9 --max-duration 600 | +| fast beam search | 2.15 | 5.22 | --epoch 30 --avg 9 --max-duration 600 | The training commands are: ```bash @@ -56,7 +831,10 @@ for m in greedy_search fast_beam_search modified_beam_search ; do done ``` - +Note that a small change is made to the `pruned_transducer_stateless7/decoder.py` in +this [PR](https://github.com/k2-fsa/icefall/pull/942) to address the +problem of emitting the first symbol at the very beginning. 
If you need a +model without this issue, please download the model from here: ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter) @@ -160,6 +938,9 @@ The WERs are: | greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 | | modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 | | modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 | +| modified_beam_search + TransformerLM shallow fusion | 2.37 | 6.48 | --iter 468000 --avg 16 | +| modified_beam_search + RNNLM + LODR | 2.24 | 5.89 | --iter 468000 --avg 16 | +| modified_beam_search + TransformerLM + LODR | 2.19 | 5.90 | --iter 468000 --avg 16 | | fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 | | greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 | | modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 | @@ -214,9 +995,12 @@ for m in greedy_search fast_beam_search modified_beam_search; do done ``` -To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM -can be found here: +You may also decode using shallow fusion with external neural network LM. To do so you need to +download a well-trained NN LM: +RNN LM: +Transformer LM: +```bash for iter in 472000; do for avg in 8 10 12 14 16 18; do ./lstm_transducer_stateless2/decode.py \ @@ -224,16 +1008,47 @@ for iter in 472000; do --avg $avg \ --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ - --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir /path/to/RNNLM \ - --rnn-lm-epoch 99 \ - --rnn-lm-avg 1 \ - --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir /ceph-data4/yangxiaoyu/pretrained_models/LM/icefall-librispeech-rnn-lm/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ done done +``` + +You may also decode using LODR + LM shallow fusion. This decoding method is proposed in . +It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be +generated by `generate-lm.sh`, or you may download it from . + +The decoding command is as follows: + +```bash +for iter in 472000; do + for avg in 8 10 12 14 16 18; do + ./lstm_transducer_stateless2/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_LODR \ + --beam 4 \ + --max-contexts 4 \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir /ceph-data4/yangxiaoyu/pretrained_models/LM/icefall-librispeech-rnn-lm/exp \ + --lm-epoch 99 \ + --lm-scale 0.4 \ + --lm-avg 1 \ + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 + done +done +``` +Note that you can also set `--lm-type transformer` to use transformer LM during LODR. But it will be slower +because it has not been optimized. 
The pre-trained transformer LM is available at Pretrained models, training logs, decoding logs, and decoding results are available at @@ -1392,6 +2207,9 @@ layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder di | greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 | | modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 | | modified beam search + RNNLM shallow fusion | 2.27 | 5.24 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + RNNLM + LODR | 2.23 | 5.17 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + TransformerLM shallow fusion | 2.27 | 5.26 | --epoch 30 --avg 10 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 2.22 | 5.11 | --epoch 30 --avg 10 --max-duration 600 | | fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 | ```bash @@ -1754,6 +2572,9 @@ subset so that the gigaspeech dataloader never exhausts. |-------------------------------------|------------|------------|---------------------------------------------| | greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 | | modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 | +| modified beam search + rnnlm shallow fusion | 1.94 | 4.2 | --iter 1224000 --avg 14 --max-duration 600 | +| modified beam search + rnnlm + LODR | 1.77 | 3.99 | --iter 1224000 --avg 14 --max-duration 600 | +| modified beam search + TransformerLM + LODR | 1.75 | 3.94 | --iter 1224000 --avg 14 --max-duration 600 | | fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 | The training commands are: @@ -1799,6 +2620,66 @@ for iter in 1224000; do done done ``` +You may also decode using shallow fusion with external neural network LM. To do so you need to +download a well-trained NN LM: +RNN LM: +Transformer LM: + +```bash +rnn_lm_scale=0.3 + +for iter in 1224000; do + for avg in 14; do + for method in modified_beam_search_rnnlm_shallow_fusion ; do + ./pruned_transducer_stateless3/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \ + --max-duration 600 \ + --decoding-method $method \ + --max-sym-per-frame 1 \ + --beam 4 \ + --max-contexts 32 \ + --rnn-lm-scale $rnn_lm_scale \ + --rnn-lm-exp-dir /path/to/RNNLM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + done + done +done +``` + +If you want to try out with LODR decoding, use the following command. This assums you have a bi-gram LM trained on LibriSpeech text. You can also download the bi-gram LM from here and put it under the directory `data/lang_bpe_500`. 
+ +```bash +rnn_lm_scale=0.4 + +for iter in 1224000; do + for avg in 14; do + for method in modified_beam_search_rnnlm_LODR ; do + ./pruned_transducer_stateless3/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \ + --max-duration 600 \ + --decoding-method $method \ + --max-sym-per-frame 1 \ + --beam 4 \ + --max-contexts 32 \ + --rnn-lm-scale $rnn_lm_scale \ + --rnn-lm-exp-dir /path/to/RNNLM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 \ + --tokens-ngram 2 \ + --ngram-lm-scale -0.14 + done + done +done +``` The pretrained models, training logs, decoding logs, and decoding results can be found at diff --git a/egs/librispeech/ASR/add_alignments.sh b/egs/librispeech/ASR/add_alignments.sh index 5e4480bf6..6c47d25a2 100755 --- a/egs/librispeech/ASR/add_alignments.sh +++ b/egs/librispeech/ASR/add_alignments.sh @@ -2,11 +2,51 @@ set -eou pipefail -alignments_dir=data/alignment +# align could be in ("mfa", "torchaudio") +# We recommend "torchaudio" +align="torchaudio" + +# It adds alignments to the existing fbank features dir (e.g., data/fbank) +# and save cuts to a new dir (e.g., data/fbank_ali). cuts_in_dir=data/fbank cuts_out_dir=data/fbank_ali -python3 ./local/add_alignment_librispeech.py \ - --alignments-dir $alignments_dir \ - --cuts-in-dir $cuts_in_dir \ - --cuts-out-dir $cuts_out_dir +if [ $align == "mfa" ]; then + # It add alignments from https://github.com/CorentinJ/librispeech-alignments, + # generated using the Montreal Forced Aligner (https://montreal-forced-aligner.readthedocs.io). + alignments_dir=data/alignment + + python3 ./local/add_alignment_librispeech.py \ + --alignments-dir $alignments_dir \ + --cuts-in-dir $cuts_in_dir \ + --cuts-out-dir $cuts_out_dir +elif [ $align == "torchaudio" ]; then + # See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/bin/modes/workflows.py for details. + # + # It use a pretrained ASR model from torchaudio to generate alignments. + # It will attach word-level alignment information (start, end, and score) to the + # supervisions in each cut. 
+ mkdir -p $cuts_out_dir + + parts=( + train-clean-100 + train-clean-360 + train-other-500 + test-clean + test-other + dev-clean + dev-other + ) + + echo "The alignments will be saved to $cuts_out_dir" + for part in ${parts[@]}; do + echo "Start to align $part" + lhotse workflows align-with-torchaudio --dont-normalize-text \ + $cuts_in_dir/librispeech_cuts_${part}.jsonl.gz \ + $cuts_out_dir/librispeech_cuts_${part}.jsonl.gz + done + echo "Finished" +else + echo "align is expected to be in ('mfa', 'torchaudio'), but got $align" + exit 1 +fi diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 2828e309e..42e14abac 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -231,9 +231,7 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip( - cut_list, labels_ali, aux_labels_ali - ): + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -258,9 +256,7 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return CutSet.from_cuts(cuts) @@ -289,9 +285,7 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = ( - out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" - ) + out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 6fac07f93..a1cfe6e75 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,9 +160,7 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -182,18 +180,14 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -227,9 +221,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * 
self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -348,9 +340,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -366,9 +356,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -638,9 +626,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -708,31 +696,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -771,9 +750,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -785,9 +762,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -821,13 +796,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3f3b1acda..7e0bf5b7b 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -551,9 +551,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" 
this_batch = [] hyp_words = [] for ref_text in texts: @@ -568,9 +566,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -602,9 +598,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -809,9 +803,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index 28c28df01..fbcbd7b29 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -157,9 +157,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index 1f2f3b137..52d2eda3b 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module): mean of the output is taken. (3) "sum": the output will be summed. """ super().__init__() - assert 0.0 <= label_smoothing < 1.0 + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction self.ignore_index = ignore_index self.label_smoothing = label_smoothing self.reduction = reduction @@ -82,13 +83,10 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index a2c0a5486..30def9c40 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -236,10 +236,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -300,9 +299,7 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -427,9 +424,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 542fb0364..8e0f73d05 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -132,17 +128,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 6419f6816..99fe64793 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,9 +393,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -422,9 +420,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -453,9 +449,7 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) - .sum() - .item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() ) return loss, info @@ -568,9 +562,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -695,10 +687,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -733,9 +725,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 00ca027a7..0566cfc81 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,9 +151,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss() else: @@ -181,18 +179,13 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" - if ( - isinstance(self.use_feat_batchnorm, bool) - and self.use_feat_batchnorm - ): + if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -273,23 +266,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -350,23 +337,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -639,9 +620,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class PositionalEncoding(nn.Module): @@ -843,9 +822,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -866,9 +843,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. 
The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 1375d7245..356d3f21b 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,11 +18,10 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ -from scaling import ScaledLinear - class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -76,9 +75,7 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = ( - self.kdim == embed_dim and self.vdim == embed_dim - ) + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout @@ -94,9 +91,7 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear( - embed_dim, 3 * embed_dim, bias=bias - ) + self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -107,12 +102,8 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) - self.bias_v = nn.Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) + self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index b906d2650..09f1eb000 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,9 +29,8 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from torch import Tensor, nn from subsampling import Conv2dSubsampling - +from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask @@ -182,9 +181,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -356,9 +353,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -373,9 +368,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = 
self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -650,9 +643,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -721,31 +714,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -784,9 +768,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -794,13 +776,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -834,13 +812,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -863,9 +837,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' 
padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 97f2f2d39..0b271a51c 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -658,9 +658,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -675,9 +673,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -709,9 +705,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -852,9 +846,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +875,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -985,9 +979,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 584b3c3fc..7892b03c6 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,6 +47,7 @@ import logging from pathlib import Path import torch +from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -55,10 +56,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from conformer import Conformer - -from icefall.utils import str2bool from icefall.lexicon import Lexicon +from icefall.utils import str2bool def get_parser(): @@ -177,9 +176,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -206,9 +205,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -273,9 
+272,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/subsampling.py b/egs/librispeech/ASR/conformer_ctc2/subsampling.py index 3fcb4196f..85a4dc8df 100644 --- a/egs/librispeech/ASR/conformer_ctc2/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc2/subsampling.py @@ -24,10 +24,9 @@ from scaling import ( ScaledConv2d, ScaledLinear, ) -from torch import nn -class Conv2dSubsampling(nn.Module): +class Conv2dSubsampling(torch.nn.Module): """Convolutional 2D subsampling (to 1/4 length). Convert an input of shape (N, T, idim) to an output @@ -61,7 +60,7 @@ class Conv2dSubsampling(nn.Module): assert in_channels >= 7 super().__init__() - self.conv = nn.Sequential( + self.conv = torch.nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 9d9c2af1f..121fdb256 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -166,13 +164,6 @@ def get_parser(): """, ) - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - parser.add_argument( "--initial-lr", type=float, @@ -505,11 +496,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -522,14 +509,6 @@ def compute_loss( nnet_output, encoder_memory, memory_mask = model( feature, supervisions, warmup=warmup ) - # logging.info('feature shape: {}'.format(feature.shape)) - # logging.info('nnet_output shape: {}'.format(nnet_output.shape)) - # logging.info('encoder_memory shape: {}'.format(encoder_memory.shape)) - # logging.info('memory_mask shape: {}'.format(memory_mask.shape)) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. 
# NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by @@ -546,9 +525,7 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError( - f"Unsupported type of graph compiler: {type(graph_compiler)}" - ) + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -575,9 +552,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -595,9 +570,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -791,9 +764,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info[ - "att_loss" - ] == float("inf"): + if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( + "inf" + ): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -806,9 +779,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -957,10 +928,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index fa179acc0..d3443dc94 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,19 +21,17 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling from attention import MultiheadAttention -from torch.nn.utils.rnn import pad_sequence - +from label_smoothing import LabelSmoothingLoss from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledLinear, ScaledEmbedding, + ScaledLinear, ) - +from subsampling import Conv2dSubsampling +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. 
Supervisions = Dict[str, torch.Tensor] @@ -210,9 +208,7 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder( - x, src_key_padding_mask=mask, warmup=warmup - ) # (T, N, C) + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) return x, mask @@ -261,23 +257,17 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -338,23 +328,17 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -417,7 +401,6 @@ class TransformerEncoderLayer(nn.Module): dim_feedforward: int = 2048, dropout: float = 0.1, layer_dropout: float = 0.075, - activation: str = "relu", ) -> None: super(TransformerEncoderLayer, self).__init__() @@ -443,11 +426,6 @@ class TransformerEncoderLayer(nn.Module): self.dropout = nn.Dropout(dropout) - # def __setstate__(self, state): - # if "activation" not in state: - # state["activation"] = nn.functional.relu - # super(TransformerEncoderLayer, self).__setstate__(state) - def forward( self, src: torch.Tensor, @@ -539,7 +517,6 @@ class TransformerDecoderLayer(nn.Module): dim_feedforward: int = 2048, dropout: float = 0.1, layer_dropout: float = 0.075, - # activation: str = "relu", normalize_before: bool = True, ) -> None: super(TransformerDecoderLayer, self).__init__() @@ -564,11 +541,6 @@ class TransformerDecoderLayer(nn.Module): self.dropout = nn.Dropout(dropout) - # def __setstate__(self, state): - # if "activation" not in state: - # state["activation"] = nn.functional.relu - # super(TransformerDecoderLayer, self).__setstate__(state) - def forward( self, tgt: torch.Tensor, @@ -653,17 +625,6 @@ class TransformerDecoderLayer(nn.Module): return tgt -def 
_get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) - - class TransformerEncoder(nn.Module): r"""TransformerEncoder is a stack of N encoder layers @@ -708,7 +669,7 @@ class TransformerEncoder(nn.Module): """ output = src - for i, mod in enumerate(self.layers): + for mod in self.layers: output = mod( output, src_mask=mask, @@ -769,7 +730,7 @@ class TransformerDecoder(nn.Module): """ output = tgt - for i, mod in enumerate(self.layers): + for mod in self.layers: output = mod( output, memory, @@ -982,9 +943,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -1005,9 +964,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc3/__init__.py b/egs/librispeech/ASR/conformer_ctc3/__init__.py new file mode 120000 index 000000000..b24e5e357 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/__init__.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/__init__.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/conformer.py b/egs/librispeech/ASR/conformer_ctc3/conformer.py new file mode 120000 index 000000000..3b84b9573 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/conformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/conformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py new file mode 100755 index 000000000..e6327bb5e --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -0,0 +1,1049 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) decode in non-streaming mode (take ctc-decoding as an example) +./conformer_ctc3/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./conformer_ctc3/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(2) decode in streaming mode (take ctc-decoding as an example) +./conformer_ctc3/decode.py \ + --epoch 30 \ + --avg 15 \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir ./conformer_ctc3/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +To evaluate symbol delay, you should: +(1) Generate cuts with word-time alignments: +./add_alignments.sh +(2) Set the argument "--manifest-dir data/fbank_ali" while decoding. +For example: +./conformer_ctc3/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./conformer_ctc3/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --manifest-dir data/fbank_ali +Note: It supports calculating symbol delay with following decoding methods: + - ctc-decoding + - 1best +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_ctc_model, get_params + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + convert_timestamp, + get_texts, + make_pad_mask, + parse_bpe_start_end_pairs, + parse_fsa_timestamps_and_texts, + setup_logger, + store_transcripts_and_timestamps, + str2bool, + write_error_stats_with_timestamps, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless4/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (0) ctc-greedy-search. It uses a sentence piece model, + i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. 
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def ctc_greedy_search( + ctc_probs: torch.Tensor, + nnet_output_lens: torch.Tensor, + sp: spm.SentencePieceProcessor, + subsampling_factor: int = 4, + frame_shift_ms: float = 10, +) -> Tuple[List[Tuple[float, float]], List[List[str]]]: + """Apply CTC greedy search + Args: + ctc_probs (torch.Tensor): + (batch, max_len, feat_dim) + nnet_output_lens (torch.Tensor): + (batch, ) + sp: + The BPE model. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + + Returns: + utt_time_pairs: + A list of pair list. utt_time_pairs[i] is a list of + (start-time, end-time) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.squeeze(2) # (B, maxlen) + mask = make_pad_mask(nnet_output_lens) + topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + + def get_first_tokens(tokens: List[int]) -> List[bool]: + is_first_token = [] + first_tokens = [] + for t in range(len(tokens)): + if tokens[t] != 0 and (t == 0 or tokens[t - 1] != tokens[t]): + is_first_token.append(True) + first_tokens.append(tokens[t]) + else: + is_first_token.append(False) + return first_tokens, is_first_token + + utt_time_pairs = [] + utt_words = [] + for utt in range(len(hyps)): + first_tokens, is_first_token = get_first_tokens(hyps[utt]) + all_tokens = sp.id_to_piece(hyps[utt]) + index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token) + words = sp.decode(first_tokens).split() + assert len(index_pairs) == len(words), ( + len(index_pairs), + len(words), + all_tokens, + ) + start = convert_timestamp( + frames=[i[0] for i in index_pairs], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + end = convert_timestamp( + # The duration in frames is (end_frame_index - start_frame_index + 1) + frames=[i[1] + 1 for i in index_pairs], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + utt_time_pairs.append(list(zip(start, end))) + utt_words.append(words) + + return utt_time_pairs, utt_words + + +def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]: + # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py + new_hyp: List[int] = [] + time: List[Tuple[int, int]] = [] + cur = 0 + start, end = -1, -1 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + start = cur + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + if start != -1: + end = cur + cur += 1 + if start != -1 and end != -1: + time.append((start, end)) + start, end = -1, -1 + return new_hyp, time + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]: + """Decode one batch and return the result in a dict. 
The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(feature, feature_lens) + + nnet_output = model.get_ctc_output(encoder_out) + # nnet_output is (N, T, C) + + if params.decoding_method == "ctc-greedy-search": + timestamps, hyps = ctc_greedy_search( + ctc_probs=nnet_output, + nnet_output_lens=encoder_out_lens, + sp=bpe_model, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + ) + key = "ctc-greedy-search" + return {key: (hyps, timestamps)} + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + encoder_out_lens.cpu(), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + 
max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, + sp=bpe_model, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + ) + key = "ctc-decoding" + return {key: (hyps, timestamps)} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}" # noqa + return {key: (hyps, timestamps)} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = f"no_rescore_hlg_scale_{params.hlg_scale}" + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, + word_table=word_table, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + ) + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}" # noqa + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + return {key: (hyps, timestamps)} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + ans[lm_scale_str] = (hyps, timestamps) + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[ + str, + List[ + Tuple[ + str, + List[str], + List[str], + List[Tuple[float, float]], + 
List[Tuple[float, float]], + ] + ], +]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + timestamps_ref = [] + for cut in batch["supervisions"]["cut"]: + for s in cut.supervisions: + time = [] + if s.alignment is not None and "word" in s.alignment: + time = [ + (aliword.start, aliword.end) + for aliword in s.alignment["word"] + if aliword.symbol != "" + ] + timestamps_ref.append(time) + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + for name, (hyps, timestamps_hyp) in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) and len(timestamps_hyp) == len( + timestamps_ref + ) + for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip( + cut_ids, hyps, texts, timestamps_hyp, timestamps_ref + ): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[ + str, + List[ + Tuple[ + List[str], + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], + ], +): + test_set_wers = dict() + test_set_delays = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts_and_timestamps(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
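+        # The call below also returns the mean and variance of the (start, end)
+        # symbol delays, which are summarized in the symbol-delay files written
+        # further below.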
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer, mean_delay, var_delay = write_error_stats_with_timestamps( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + with_end_time=True, + ) + test_set_wers[key] = wer + test_set_delays[key] = (mean_delay, var_delay) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + # sort according to the mean start symbol delay + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0]) + delays_info = ( + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" + ) + with open(delays_info, "w") as f: + print("settings\t(start, end) symbol-delay (s) (start, end)", file=f) + for key, val in test_set_delays: + print( + "{}\tmean: {}, variance: {}".format(key, val[0], val[1]), + file=f, + ) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format( + test_set_name + ) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_delays: + s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-greedy-search", + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.vocab_size = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]: + HLG = None 
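+        # For ctc-decoding / ctc-greedy-search we decode with a CTC topology
+        # built over the BPE tokens, so no HLG graph or word table is needed.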
+ H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_ctc_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
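+    # (lhotse's K2SpeechRecognitionDataset only attaches the cuts to
+    # batch["supervisions"] when return_cuts is True)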
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py b/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/export.py b/egs/librispeech/ASR/conformer_ctc3/export.py new file mode 100755 index 000000000..c5b95d981 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/export.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: + +(1) Export to torchscript model using torch.jit.trace() + +./conformer_ctc3/export.py \ + --exp-dir ./conformer_ctc3/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +It will generates the file: `jit_trace.pt`. + +(2) Export `model.state_dict()` + +./conformer_ctc3/export.py \ + --exp-dir ./conformer_ctc3/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. 
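+
+For example, a rough sketch of loading it back in Python (this assumes the
+default 500-token BPE setup and a non-streaming model; adjust to your
+configuration):
+
+    from train import get_ctc_model, get_params
+    from icefall.checkpoint import load_checkpoint
+
+    params = get_params()
+    params.vocab_size = 500
+    params.dynamic_chunk_training = False
+    params.short_chunk_size = 25
+    params.num_left_chunks = 4
+    params.causal_convolution = False
+    model = get_ctc_model(params)
+    load_checkpoint("conformer_ctc3/exp/pretrained.pt", model)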
+ +To use the generated file with `conformer_ctc3/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./conformer_ctc3/decode.py \ + --exp-dir ./conformer_ctc3/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --lang-dir data/lang_bpe_500 +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_ctc_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless4/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--streaming-model", + type=str2bool, + default=False, + help="""Whether to export a streaming model, if the models in exp-dir + are streaming model, this should be True. 
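+        Note that torch.jit.trace() export of streaming models is not
+        supported yet; main() asserts `not params.streaming_model` when
+        --jit-trace is given.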
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + params.vocab_size = num_classes + + if params.streaming_model: + assert params.causal_convolution + + logging.info(params) + + logging.info("About to create model") + model = get_ctc_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit_trace: + # TODO: will support streaming mode + assert not params.streaming_model + convert_scaled_to_non_scaled(model, inplace=True) + + logging.info("Using torch.jit.trace()") + + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + traced_model = torch.jit.trace(model, (x, x_lens)) + + filename = params.exp_dir / "jit_trace.pt" + traced_model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.trace()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename 
= params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py new file mode 100755 index 000000000..76db46cc8 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Usage (for non-streaming mode): + +(1) ctc-decoding +./conformer_ctc3/pretrained.py \ + --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./conformer_ctc3/pretrained.py \ + --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./conformer_ctc3/pretrained.py \ + --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) whole-lattice-rescoring +./conformer_ctc3/pretrained.py \ + --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. 
+ """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + + logging.info(f"{params}") + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + nnet_output, _ = model(features, feature_lengths) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/lstmp.py b/egs/librispeech/ASR/conformer_ctc3/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/model.py b/egs/librispeech/ASR/conformer_ctc3/model.py new file mode 100644 index 000000000..f56df2006 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/model.py @@ -0,0 +1,122 @@ +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Tuple + +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear + + +class CTCModel(nn.Module): + """It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf + "Connectionist Temporal Classification: Labelling Unsegmented + Sequence Data with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + encoder_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. 
Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + encoder_dim: + The feature embedding dimension. + vocab_size: + The vocabulary size. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder = encoder + self.ctc_output_module = nn.Sequential( + nn.Dropout(p=0.1), + ScaledLinear(encoder_dim, vocab_size), + ) + + def get_ctc_output( + self, + encoder_out: torch.Tensor, + delay_penalty: float = 0.0, + blank_threshold: float = 0.99, + ): + """Compute ctc log-prob and optionally (delay_penalty > 0) apply delay penalty. + We first split utterance into sub-utterances according to the + blank probs, and then add sawtooth-like "blank-bonus" values to + the blank probs. + See https://github.com/k2-fsa/icefall/pull/669 for details. + + Args: + encoder_out: + A tensor with shape of (N, T, C). + delay_penalty: + A constant used to scale the delay penalty score. + blank_threshold: + The threshold used to split utterance into sub-utterances. + """ + output = self.ctc_output_module(encoder_out) + log_prob = nn.functional.log_softmax(output, dim=-1) + + if self.training and delay_penalty > 0: + T_arange = torch.arange(encoder_out.shape[1]).to(device=encoder_out.device) + # split into sub-utterances using the blank-id + mask = log_prob[:, :, 0] >= math.log(blank_threshold) # (B, T) + mask[:, 0] = True + cummax_out = (T_arange * mask).cummax(dim=-1)[0] # (B, T) + # the sawtooth "blank-bonus" value + penalty = T_arange - cummax_out # (B, T) + penalty_all = torch.zeros_like(log_prob) + penalty_all[:, :, 0] = delay_penalty * penalty + # apply latency penalty on probs + log_prob = log_prob + penalty_all + + return log_prob + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + warmup: float = 1.0, + delay_penalty: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + delay_penalty: + A constant used to scale the delay penalty score. + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(encoder_out_lens > 0) + nnet_output = self.get_ctc_output(encoder_out, delay_penalty=delay_penalty) + return nnet_output, encoder_out_lens diff --git a/egs/librispeech/ASR/conformer_ctc3/optim.py b/egs/librispeech/ASR/conformer_ctc3/optim.py new file mode 120000 index 000000000..e2deb4492 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py new file mode 100755 index 000000000..880945ea0 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Usage (for non-streaming mode): + +(1) ctc-decoding +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(2) 1best +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(3) nbest-rescoring +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(4) whole-lattice-rescoring +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_ctc_model, get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. 
+ We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("About to create model") + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + if params.simulate_streaming: + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=features, + x_lens=feature_lengths, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + nnet_output = model.get_ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if 
params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/scaling.py b/egs/librispeech/ASR/conformer_ctc3/scaling.py new file mode 120000 index 000000000..09d802cc4 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py b/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/test_model.py b/egs/librispeech/ASR/conformer_ctc3/test_model.py new file mode 100755 index 000000000..b97b7eed8 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/test_model.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./conformer_ctc3/test_model.py +""" + +import torch + +from train import get_params, get_ctc_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + + params.dynamic_chunk_training = False + params.short_chunk_size = 25 + params.num_left_chunks = 4 + params.causal_convolution = False + + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + features = torch.randn(2, 100, 80) + feature_lengths = torch.full((2,), 100) + model(x=features, x_lens=feature_lengths) + + +def test_model_streaming(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + + params.dynamic_chunk_training = True + params.short_chunk_size = 25 + params.num_left_chunks = 4 + params.causal_convolution = True + + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + features = torch.randn(2, 100, 80) + feature_lengths = torch.full((2,), 100) + encoder_out, _ = model.encoder(x=features, x_lens=feature_lengths) + model.get_ctc_output(encoder_out) + + +def main(): + test_model() + test_model_streaming() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py new file mode 100755 index 000000000..2cd223945 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -0,0 +1,1109 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --max-duration 550 + +# train a streaming model +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 \ + --delay-penalty 0.0 +""" + +import argparse +import copy +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import CTCModel +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc3/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. 
Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant used to scale the symbol delay penalty, + to encourage symbol emit earlier for streaming models. + It is almost the same as the `delay_penalty` in our `rnnt_loss`, See + https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + + parser.add_argument( + "--nnet-delay-penalty", + type=float, + default=0.0, + help="""A constant to penalize symbol delay, which is applied on + the nnet_output after log-softmax. + We recommend using --delay-penalty instead. + See https://github.com/k2-fsa/icefall/pull/669 for details.""", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
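+
+        Note: in this recipe the warm-up value is stored as `model_warm_step`
+        and is passed to the model for layer warm-up; it does not affect the
+        learning-rate schedule (see the comment in the dict below).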
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + # parameters for loss + "beam_size": 10, + "reduction": "none", + "use_double_scores": True, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, + ) + return encoder + + +def get_ctc_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + model = CTCModel( + encoder=encoder, + encoder_dim=params.encoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
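+
+    # load_checkpoint() restores the state of `model` (and, when given,
+    # `model_avg`, `optimizer` and `scheduler`) in place and returns the
+    # remaining metadata stored in the checkpoint, e.g. the best_* statistics
+    # that are copied back into `params` below.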
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
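+    Returns:
+      Return a tuple of two elements: the summed CTC loss over the (finite)
+      utterances in the batch, and a `MetricsTracker` instance holding
+      statistics such as the number of frames and utterances.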
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with torch.set_grad_enabled(is_training): + nnet_output, encoder_out_lens = model( + feature, + feature_lens, + warmup=warmup, + delay_penalty=params.nnet_delay_penalty if warmup >= 1.0 else 0, + ) + assert torch.all(encoder_out_lens > 0) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + ctc_loss_is_finite = torch.isfinite(ctc_loss) + if not torch.all(ctc_loss_is_finite): + logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}") + ctc_loss = ctc_loss[ctc_loss_is_finite] + + # If either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if torch.all(~ctc_loss_is_finite): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + loss = ctc_loss.sum() + + assert loss.requires_grad == is_training + + info = MetricsTracker() + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = supervision_segments[:, 2].sum().item() + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. 
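+    # Hence info["loss"] holds the total, un-normalized loss of this batch;
+    # the per-frame average reported in the logs is computed later as
+    # tot_loss["loss"] / tot_loss["frames"].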
+ info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
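+
+        # Standard AMP sequence: backward on the scaled loss, advance the LR
+        # scheduler by global batch index, then let GradScaler unscale the
+        # gradients, run the optimizer step and update its scale factor.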
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + params.vocab_size = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + if "lang_bpe" in str(params.lang_dir): + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in str(params.lang_dir): + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + need_repeat_flag=params.delay_penalty > 0, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) + + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 97c8d83a2..53e48eb13 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,9 +156,7 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -176,18 +174,14 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,9 +215,7 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) + src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -342,9 +334,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -360,9 +350,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -632,9 +620,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -702,31 +690,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != 
[1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -765,9 +744,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -779,9 +756,7 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -815,13 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -844,9 +815,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index fc9861489..74f6e73fa 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -478,9 +478,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -512,9 +510,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = 
params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -653,9 +649,7 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) + torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") return model.to(device) @@ -666,14 +660,22 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + # CAUTION: `test_sets` is for displaying only. # If you want to skip test-clean, you have to skip # it inside the for loop. That is, use # # if test_set == 'test-clean': continue - # test_sets = ["test-clean", "test-other"] - for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + test_dls = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl, params=params, @@ -687,9 +689,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index 5c3e1222e..ad9415987 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,13 +25,9 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), + nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -115,17 +111,13 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) + torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) + self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
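In `Conv2dSubsampling` above, the `odim * (((idim - 1) // 2 - 1) // 2)` input size
of `self.out` follows from the two kernel-3, stride-2, unpadded convolutions, each
of which maps a length n to (n - 1) // 2. A minimal sketch of that arithmetic (the
helper name and the example sizes are illustrative only, not part of the recipe):

def conv2d_subsampled_len(n: int) -> int:
    # Each Conv2d(kernel_size=3, stride=2, padding=0) maps n -> (n - 3) // 2 + 1,
    # which equals (n - 1) // 2; Conv2dSubsampling applies it twice.
    return ((n - 1) // 2 - 1) // 2

# 80 mel bins -> 19 frequency positions feeding the linear layer, and
# 100 input frames -> 24 output frames (roughly the 4x subsampling factor
# used elsewhere in the recipe).
assert conv2d_subsampled_len(80) == 19
assert conv2d_subsampled_len(100) == 24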
diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index 937845d77..d0bb017dd 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling import torch +from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 08e680607..25d18076d 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,17 +1,16 @@ #!/usr/bin/env python3 import torch +from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, + add_eos, + add_sos, + decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, ) -from torch.nn.utils.rnn import pad_sequence - def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index 011dadd73..100bc846a 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -30,29 +30,22 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -109,6 +102,41 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_mmi/exp-attn", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--use-pruned-intersect", + type=str2bool, + default=False, + help="""Whether to use `intersect_dense_pruned` to get denominator + lattice.""", + ) + return parser @@ -123,12 +151,6 @@ def get_params() -> AttributeDict: Explanation of options saved in `params`: - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. @@ -173,8 +195,6 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_mmi/exp_500_with_attention"), - "lang_dir": Path("data/lang_bpe_500"), "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -193,15 +213,12 @@ def get_params() -> AttributeDict: "beam_size": 6, # will change it to 8 after some batches (see code) "reduction": "sum", "use_double_scores": True, - # "att_rate": 0.0, - # "num_decoder_layers": 0, "att_rate": 0.7, "num_decoder_layers": 6, # parameters for Noam "weight_decay": 1e-6, "lr_factor": 5.0, "warm_step": 80000, - "use_pruned_intersect": False, "den_scale": 1.0, # use alignments before this number of batches "use_ali_until": 13000, @@ -370,10 +387,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -673,7 +687,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -757,24 +771,40 @@ def run(rank, world_size, args): valid_ali = None librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: @@ -813,6 +843,7 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) world_size = args.world_size assert world_size >= 1 diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 9a5bdcce2..f9f80632e 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -30,29 +30,22 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) +from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool def get_parser(): @@ -109,6 +102,26 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_mmi/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + parser.add_argument( "--seed", type=int, @@ -116,6 +129,14 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) + parser.add_argument( + "--use-pruned-intersect", + type=str2bool, + default=False, + help="""Whether to use `intersect_dense_pruned` to get denominator + lattice.""", + ) + return parser @@ -130,12 +151,6 @@ def get_params() -> AttributeDict: Explanation of options saved in `params`: - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. @@ -180,8 +195,6 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_mmi/exp_500"), - "lang_dir": Path("data/lang_bpe_500"), "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -202,13 +215,10 @@ def get_params() -> AttributeDict: "use_double_scores": True, "att_rate": 0.0, "num_decoder_layers": 0, - # "att_rate": 0.7, - # "num_decoder_layers": 6, # parameters for Noam "weight_decay": 1e-6, "lr_factor": 5.0, "warm_step": 80000, - "use_pruned_intersect": False, "den_scale": 1.0, # use alignments before this number of batches "use_ali_until": 13000, @@ -377,10 +387,7 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if ( - params.batch_idx_train > params.use_ali_until - and params.beam_size < 8 - ): + if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -764,25 +771,41 @@ def run(rank, world_size, args): valid_ali = None librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if ( - params.batch_idx_train >= params.use_ali_until - and train_ali is not None - ): + if params.batch_idx_train >= params.use_ali_until and train_ali is not None: # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: @@ -821,6 +844,7 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) world_size = args.world_size assert world_size >= 1 diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 68a4ff65c..2542d9abe 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,9 +148,7 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) + self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -182,9 +180,7 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) + encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -274,9 +270,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -341,9 +335,7 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -616,9 +608,7 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class 
PositionalEncoding(nn.Module): @@ -887,9 +877,7 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -910,9 +898,7 @@ def encoder_padding_mask( return mask -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 620d69a19..7be3299f3 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -215,8 +215,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +298,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -440,18 +432,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -461,10 +449,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -506,9 +491,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,9 +523,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -569,9 +552,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 8ca7d5568..91f50cf67 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context 
= x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})." ) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -551,9 +533,7 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query( - torch.cat([right_context, utterance, summary]) - ) + query = self.emb_to_query(torch.cat([right_context, utterance, summary])) # compute key and value with [memory, right_context, utterance]. 
key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -564,16 +544,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -588,9 +564,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection outputs = self.out_proj(attention) @@ -672,12 +646,7 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - ( - output_right_context_utterance, - output_memory, - _, - _, - ) = self._forward_impl( + (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( utterance, right_context, summary, @@ -947,13 +916,9 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -992,14 +957,10 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - ) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) summary = summary[:1] else: - summary = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) ( output_right_context_utterance, output_memory, @@ -1014,9 +975,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1151,11 +1110,7 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - ( - src_att, - output_memory, - attn_cache, - ) = self._apply_attention_module_infer( + (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1295,9 +1250,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> 
torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1316,9 +1269,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1479,9 +1430,7 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1 - ] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1643,12 +1592,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1693,17 +1638,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1766,9 +1705,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 4930881ea..09a3e96b0 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -136,8 +136,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -181,9 +180,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +209,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -279,9 +278,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index 9494e1fc1..c211b215e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,14 +68,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 61dbe8658..e5a7c7116 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -211,8 +211,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +370,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +387,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +546,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +596,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -761,17 +750,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
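As a side note on decode_one_chunk above: each chunk must keep at least one frame after subsampling, first/last-frame trimming, and right-context trimming, which is what the tail_length lower bound guarantees. A minimal sketch with assumed hyper-parameters (subsampling_factor=4, right_context_length=8); the real script also extends feature_lens by the same amount:

import math
import torch

LOG_EPSILON = math.log(1e-10)
subsampling_factor, right_context_length = 4, 8                    # assumed values
tail_length = 3 * subsampling_factor + right_context_length + 3    # = 23

features = torch.randn(2, 17, 80)   # (batch, num_frames, feature_dim), too short
if features.size(1) < tail_length:
    pad_length = tail_length - features.size(1)
    features = torch.nn.functional.pad(
        features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPSILON
    )
assert features.size(1) == tail_length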
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -781,10 +766,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -831,9 +813,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,9 +847,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -896,9 +876,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c07d8f76b..6bb5505aa 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -289,8 +286,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -636,11 +632,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
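The warmup argument documented above drives the pruned-loss scale computed a few lines below. A minimal sketch of that schedule; the helper name pruned_loss_scale is only for illustration (the recipe computes the value inline in compute_loss):

def pruned_loss_scale(warmup: float) -> float:
    if warmup < 1.0:
        return 0.0      # pruned loss switched off until warm-up completes
    if 1.0 < warmup < 2.0:
        return 0.1      # introduce the pruned loss gently
    return 1.0          # fully warmed up

assert [pruned_loss_scale(w) for w in (0.5, 1.5, 3.0)] == [0.0, 0.1, 1.0]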
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +660,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +856,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,16 +964,16 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 98b8290b5..d022d463e 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -215,8 +215,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -301,10 +298,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -427,9 +421,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -440,18 +432,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -461,10 +449,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -506,9 +491,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -540,9 +523,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -569,9 +552,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index f16f5acc7..3cedf99b6 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,7 +35,6 @@ from scaling import ( from icefall.utils import make_pad_mask - LOG_EPSILON = math.log(1e-10) @@ -127,9 +126,7 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack( - attn_caches[li][si], dim=1 - ) + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) conv_caches = [] for layer in state_list[0][1]: @@ -268,9 +265,7 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange( - self.chunk_length, self.chunk_length + self.cache_size - ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -284,9 +279,7 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context( - self, right_context: torch.Tensor, B: int - ) -> torch.Tensor: + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: """ Args: right_context: @@ -337,12 +330,8 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros( - B, D, self.cache_size, device=x.device, dtype=x.dtype - ) - pad_utterance = torch.cat( - [cache, utterance], dim=2 - ) # (B, D, cache + U) + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -355,9 +344,7 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context( - right_context, B - ) # (B, D, R) + right_context = self._merge_right_context(right_context, B) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -458,8 +445,7 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple of" - f"nhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})." 
) self.embed_dim = embed_dim @@ -469,9 +455,7 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -513,9 +497,7 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view( - B, self.nhead, Q, -1 - ) + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -561,16 +543,12 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat( - [value[: M + R], left_context_val, value[M + R :]] - ) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous() - .view(-1, B * self.nhead, self.head_dim) - .transpose(0, 1) + tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -585,9 +563,7 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = ( - attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) - ) + attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -905,13 +881,11 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:-1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -948,18 +922,12 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute( - 2, 0, 1 - )[:1, :, :] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :1, :, : + ] else: - memory = torch.empty(0).to( - dtype=utterance.dtype, device=utterance.device - ) - ( - output_right_context_utterance, - next_key, - next_val, - ) = self.attention.infer( + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val,) = self.attention.infer( utterance=utterance, right_context=right_context, memory=pre_memory, @@ -967,9 +935,7 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache( - next_key, next_val, memory, attn_cache - ) + attn_cache = 
self._update_attn_cache(next_key, next_val, memory, attn_cache) return output_right_context_utterance, attn_cache def forward( @@ -1226,9 +1192,7 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil( - (T - self.right_context_length) / self.chunk_length - ) + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1247,9 +1211,7 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths( - self, chunk_idx: int, U: int - ) -> List[int]: + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1549,12 +1511,8 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), - torch.zeros( - self.left_context_length, self.d_model, device=device - ), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), ] for _ in range(self.num_encoder_layers) ] @@ -1599,17 +1557,11 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if ( - left_context_length != 0 - and left_context_length % subsampling_factor != 0 - ): + if left_context_length != 0 and left_context_length % subsampling_factor != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if ( - right_context_length != 0 - and right_context_length % subsampling_factor != 0 - ): + if right_context_length != 0 and right_context_length % subsampling_factor != 0: raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1672,9 +1624,7 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder( - x, x_lens, warmup=warmup - ) # (T, N, C) + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py new file mode 100644 index 000000000..4a844b79f --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py @@ -0,0 +1,1821 @@ +# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# It is modified based on +# 1) https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py # noqa +# 2) https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledConv2d, + ScaledLinear, +) + +from icefall.utils import make_pad_mask + +LOG_EPSILON = math.log(1e-10) + + +def unstack_states( + states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] +) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]: + """Unstack the emformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Args: + states: + A tuple of 2 elements. + ``states[0]`` is the attention caches of a batch of utterance. + ``states[1]`` is the convolution caches of a batch of utterance. + ``len(states[0])`` and ``len(states[1])`` both eqaul to number of layers. # noqa + + Returns: + A list of states. + ``states[i]`` is a tuple of 2 elements of i-th utterance. + ``states[i][0]`` is the attention caches of i-th utterance. + ``states[i][1]`` is the convolution caches of i-th utterance. + ``len(states[i][0])`` and ``len(states[i][1])`` both eqaul to number of layers. # noqa + """ + + attn_caches, conv_caches = states + batch_size = conv_caches[0].size(0) + num_layers = len(attn_caches) + + list_attn_caches = [None] * batch_size + for i in range(batch_size): + list_attn_caches[i] = [[] for _ in range(num_layers)] + for li, layer in enumerate(attn_caches): + for s in layer: + s_list = s.unbind(dim=1) + for bi, b in enumerate(list_attn_caches): + b[li].append(s_list[bi]) + + list_conv_caches = [None] * batch_size + for i in range(batch_size): + list_conv_caches[i] = [None] * num_layers + for li, layer in enumerate(conv_caches): + c_list = layer.unbind(dim=0) + for bi, b in enumerate(list_conv_caches): + b[li] = c_list[bi] + + ans = [None] * batch_size + for i in range(batch_size): + ans[i] = [list_attn_caches[i], list_conv_caches[i]] + + return ans + + +def stack_states( + state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]] +) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]: + """Stack list of emformer states that correspond to separate utterances + into a single emformer state so that it can be used as an input for + emformer when those utterances are formed into a batch. + + Note: + It is the inverse of :func:`unstack_states`. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the emformer model for a single utterance. + ``states[i]`` is a tuple of 2 elements of i-th utterance. + ``states[i][0]`` is the attention caches of i-th utterance. + ``states[i][1]`` is the convolution caches of i-th utterance. + ``len(states[i][0])`` and ``len(states[i][1])`` both eqaul to number of layers. # noqa + + Returns: + A new state corresponding to a batch of utterances. + See the input argument of :func:`unstack_states` for the meaning + of the returned tensor. 
+ """ + batch_size = len(state_list) + + attn_caches = [] + for layer in state_list[0][0]: + if batch_size > 1: + # Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s] # noqa + attn_caches.append([[s] for s in layer]) + else: + attn_caches.append([s.unsqueeze(1) for s in layer]) + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states[0]): + for si, s in enumerate(layer): + attn_caches[li][si].append(s) + if b == batch_size - 1: + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + + conv_caches = [] + for layer in state_list[0][1]: + if batch_size > 1: + # Note: We will stack conv_caches[layer][] later to get conv_caches[layer] # noqa + conv_caches.append([layer]) + else: + conv_caches.append(layer.unsqueeze(0)) + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states[1]): + conv_caches[li].append(layer) + if b == batch_size - 1: + conv_caches[li] = torch.stack(conv_caches[li], dim=0) + + return [attn_caches, conv_caches] + + +class ConvolutionModule(nn.Module): + """ConvolutionModule. + + Modified from https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa + + Args: + chunk_length (int): + Length of each chunk. + right_context_length (int): + Length of right context. + channels (int): + The number of input channels and output channels of conv layers. + kernel_size (int): + Kernerl size of conv layers. + bias (bool): + Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, + chunk_length: int, + right_context_length: int, + channels: int, + kernel_size: int, + bias: bool = True, + is_pnnx: bool = True, + ) -> None: + """Construct an ConvolutionModule object.""" + super().__init__() + self.is_pnnx = is_pnnx + # kernerl_size should be an odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0, kernel_size + + self.chunk_length = chunk_length + self.right_context_length = right_context_length + self.channels = channels + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # After pointwise_conv1 we put x through a gated linear unit + # (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in + # the range 1 to 4, but sometimes, for some reason, for layer 0 the rms + # ends up being very large, between 50 and 100 for different channels. + # This will cause very peaky and sparse derivatives for the sigmoid + # gating function, which will tend to make the loss function not learn + # effectively. (for most layers the average absolute values are in the + # range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for + # different layers, which likely breaks down as 0.5 for the "linear" + # half and 0.2 to 0.3 for the part that goes into the sigmoid. + # The idea is that if we constrain the rms values to a reasonable range + # via a constraint of max_abs=10.0, it will be in a better position to + # start learning something, i.e. to latch onto the correct range. 
+ self.deriv_balancer1 = ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + + # make it causal by padding cached (kernel_size - 1) frames on the left + self.cache_size = kernel_size - 1 + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=0, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25, + ) + + def _split_right_context( + self, + pad_utterance: torch.Tensor, + right_context: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + pad_utterance: + Its shape is (cache_size + U, B, D). + right_context: + Its shape is (R, B, D). + + Returns: + Right context segments padding with corresponding context. + Its shape is (num_segs * B, D, cache_size + right_context_length). + """ + U_, B, D = pad_utterance.size() + R = right_context.size(0) + assert self.right_context_length != 0 + assert R % self.right_context_length == 0 + num_chunks = R // self.right_context_length + right_context = right_context.reshape( + num_chunks, self.right_context_length, B, D + ) + right_context = right_context.permute(0, 2, 1, 3).reshape( + num_chunks * B, self.right_context_length, D + ) + + intervals = torch.arange( + 0, self.chunk_length * (num_chunks - 1), self.chunk_length + ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + indexes = intervals.unsqueeze(1) + first.unsqueeze(0) + indexes = torch.cat( + [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] + ) + padding = pad_utterance[indexes] # (num_chunks, cache_size, B, D) + padding = padding.permute(0, 2, 1, 3).reshape( + num_chunks * B, self.cache_size, D + ) + + pad_right_context = torch.cat([padding, right_context], dim=1) + # (num_chunks * B, cache_size + right_context_length, D) + return pad_right_context.permute(0, 2, 1) + + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + """ + Args: + right_context: + Right context segments. + It shape is (num_segs * B, D, right_context_length). + B: + Batch size. + + Returns: + A tensor of shape (B, D, R), where + R = num_segs * right_context_length. + """ + right_context = right_context.reshape( + -1, B, self.channels, self.right_context_length + ) + right_context = right_context.permute(1, 2, 0, 3) + right_context = right_context.reshape(B, self.channels, -1) + return right_context + + def forward( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Causal convolution module. + + Args: + utterance (torch.Tensor): + Utterance tensor of shape (U, B, D). + right_context (torch.Tensor): + Right context tensor of shape (R, B, D). + + Returns: + A tuple of 2 tensors: + - output utterance of shape (U, B, D). + - output right_context of shape (R, B, D). 
+ """ + U, B, D = utterance.size() + R, _, _ = right_context.size() + + # point-wise conv and GLU mechanism + x = torch.cat([right_context, utterance], dim=0) # (R + U, B, D) + x = x.permute(1, 2, 0) # (B, D, R + U) + x = self.pointwise_conv1(x) # (B, 2 * D, R + U) + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (B, D, R + U) + utterance = x[:, :, R:] # (B, D, U) + right_context = x[:, :, :R] # (B, D, R) + + # make causal convolution + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + + # depth-wise conv on utterance + utterance = self.depthwise_conv(pad_utterance) # (B, D, U) + + if self.right_context_length > 0: + # depth-wise conv on right_context + pad_right_context = self._split_right_context( + pad_utterance.permute(2, 0, 1), right_context.permute(2, 0, 1) + ) # (num_segs * B, D, cache_size + right_context_length) + right_context = self.depthwise_conv( + pad_right_context + ) # (num_segs * B, D, right_context_length) + right_context = self._merge_right_context(right_context, B) # (B, D, R) + + x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) + x = self.deriv_balancer2(x) + x = self.activation(x) + + # point-wise conv + x = self.pointwise_conv2(x) # (B, D, R + U) + + right_context = x[:, :, :R] # (B, D, R) + utterance = x[:, :, R:] # (B, D, U) + return ( + utterance.permute(2, 0, 1), + right_context.permute(2, 0, 1), + ) + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + cache: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Causal convolution module applied on both utterance and right_context. + + Args: + utterance (torch.Tensor): + Utterance tensor of shape (U, B, D). + right_context (torch.Tensor): + Right context tensor of shape (R, B, D). + cache (torch.Tensor, optional): + Cached tensor for left padding of shape (B, D, cache_size). + + Returns: + A tuple of 3 tensors: + - output utterance of shape (U, B, D). + - output right_context of shape (R, B, D). + - updated cache tensor of shape (B, D, cache_size). + """ + if self.is_pnnx is False: + U, B, D = utterance.size() + R, _, _ = right_context.size() + else: + U = self.chunk_length + B = 1 + D = self.channels + R = self.right_context_length + + # point-wise conv + x = torch.cat([utterance, right_context], dim=0) # (U + R, B, D) + x = x.permute(1, 2, 0) # (B, D, U + R) + x = self.pointwise_conv1(x) # (B, 2 * D, U + R) + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (B, D, U + R) + + # make causal convolution + assert cache.shape == (B, D, self.cache_size), cache.shape + x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R) + # update cache + new_cache = x[:, :, -R - self.cache_size : -R] + + # 1-D depth-wise conv + x = self.depthwise_conv(x) # (B, D, U + R) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + # point-wise conv + x = self.pointwise_conv2(x) # (B, D, U + R) + + utterance = x[:, :, :U] # (B, D, U) + right_context = x[:, :, U:] # (B, D, R) + return ( + utterance.permute(2, 0, 1), + right_context.permute(2, 0, 1), + new_cache, + ) + + +class EmformerAttention(nn.Module): + r"""Emformer layer attention module. + + Args: + embed_dim (int): + Embedding dimension. + nhead (int): + Number of attention heads in each Emformer layer. + dropout (float, optional): + Dropout probability. (Default: 0.0) + tanh_on_mem (bool, optional): + If ``True``, applies tanh to memory elements. 
(Default: ``False``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. (Default: -1e8) + """ + + def __init__( + self, + embed_dim: int, + nhead: int, + left_context_length: int, + chunk_length: int, + right_context_length: int, + memory_size: int, + dropout: float = 0.0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + is_pnnx: bool = True, + ): + super().__init__() + self.is_pnnx = is_pnnx + + if embed_dim % nhead != 0: + raise ValueError( + f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})." + ) + + self.embed_dim = embed_dim + self.nhead = nhead + self.tanh_on_mem = tanh_on_mem + self.negative_inf = negative_inf + self.head_dim = embed_dim // nhead + self.dropout = dropout + + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.chunk_length = chunk_length + self.memory_size = memory_size + + self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) + + def _gen_attention_probs( + self, + attention_weights: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Given the entire attention weights, mask out unecessary connections + and optionally with padding positions, to obtain underlying chunk-wise + attention probabilities. + + B: batch size; + Q: length of query; + KV: length of key and value. + + Args: + attention_weights (torch.Tensor): + Attention weights computed on the entire concatenated tensor + with shape (B * nhead, Q, KV). + attention_mask (torch.Tensor): + Mask tensor where chunk-wise connections are filled with `False`, + and other unnecessary connections are filled with `True`, + with shape (Q, KV). + padding_mask (torch.Tensor, optional): + Mask tensor where the padding positions are fill with `True`, + and other positions are filled with `False`, with shapa `(B, KV)`. + + Returns: + A tensor of shape (B * nhead, Q, KV). 
+ """ + attention_weights_float = attention_weights.float() + attention_weights_float = attention_weights_float.masked_fill( + attention_mask.unsqueeze(0), self.negative_inf + ) + if padding_mask is not None: + Q = attention_weights.size(1) + B = attention_weights.size(0) // self.nhead + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + self.negative_inf, + ) + attention_weights_float = attention_weights_float.view( + B * self.nhead, Q, -1 + ) + + attention_probs = nn.functional.softmax( + attention_weights_float, dim=-1 + ).type_as(attention_weights) + + attention_probs = nn.functional.dropout( + attention_probs, p=self.dropout, training=self.training + ) + return attention_probs + + def _forward_impl( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + left_context_key: Optional[torch.Tensor] = None, + left_context_val: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Underlying chunk-wise attention implementation.""" + if self.is_pnnx is False: + U, B, _ = utterance.size() + R = right_context.size(0) + M = memory.size(0) + else: + U = self.chunk_length + B = 1 + R = self.right_context_length + M = self.memory_size + L = self.left_context_length + + scaling = float(self.head_dim) ** -0.5 + + # compute query with [right_context, utterance]. + query = self.emb_to_query(torch.cat([right_context, utterance])) + # compute key and value with [memory, right_context, utterance]. + key, value = self.emb_to_key_value( + torch.cat([memory, right_context, utterance]) + ).chunk(chunks=2, dim=2) + + if left_context_key is not None and left_context_val is not None: + # now compute key and value with + # [memory, right context, left context, uttrance] + # this is used in inference mode + key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + + # Q = query.size(0) + Q = U + R + + # KV = key.size(0) + + if self.is_pnnx is True: + reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2) + reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute( + 1, 0, 2 + ) + reshaped_value = value.view( + M + R + U + L, self.nhead, self.head_dim + ).permute(1, 0, 2) + else: + reshaped_query, reshaped_key, reshaped_value = [ + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) + for tensor in [query, key, value] + ] # (B * nhead, Q or KV, head_dim) + if self.is_pnnx is True: + attention_weights = torch.bmm( + reshaped_query * scaling, reshaped_key.permute(0, 2, 1) + ) # (B * nhead, Q, KV) + else: + attention_weights = torch.bmm( + reshaped_query * scaling, reshaped_key.transpose(1, 2) + ) # (B * nhead, Q, KV) + + # compute attention probabilities + if False: + attention_probs = self._gen_attention_probs( + attention_weights, attention_mask, padding_mask + ) + else: + attention_probs = nn.functional.softmax(attention_weights, dim=-1) + + # compute attention outputs + attention = torch.bmm(attention_probs, reshaped_value) + assert attention.shape == (B * self.nhead, Q, self.head_dim) + if self.is_pnnx is True: + attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim) + # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim) + # We have to change 
InnerProduct in ncnn to ignore the extra dim below + attention = attention.unsqueeze(1) + else: + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) + + # apply output projection + output_right_context_utterance = self.out_proj(attention) + # The return shape of output_right_context_utterance is (10, 1, 512) + + return output_right_context_utterance, key, value + + def forward( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO: Modify docs. + """Forward pass for training and validation mode. + + B: batch size; + D: embedding dimension; + R: length of the hard-copied right contexts; + U: length of full utterance; + M: length of memory vectors. + + It computes a `big` attention matrix on full utterance and + then utilizes a pre-computed mask to simulate chunk-wise attention. + + It concatenates three blocks: hard-copied right contexts, + and full utterance, as a `big` block, + to compute the query tensor: + query = [right_context, utterance], + with length Q = R + U. + It concatenates the three blocks: memory vectors, + hard-copied right contexts, and full utterance as another `big` block, + to compute the key and value tensors: + key & value = [memory, right_context, utterance], + with length KV = M + R + U. + Attention scores is computed with above `big` query and key. + + Then the underlying chunk-wise attention is obtained by applying + the attention mask. Suppose + c_i: chunk at index i; + r_i: right context that c_i can use; + l_i: left context that c_i can use; + m_i: past memory vectors from previous layer that c_i can use; + The target chunk-wise attention is: + c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key) + + Args: + utterance (torch.Tensor): + Full utterance frames, with shape (U, B, D). + right_context (torch.Tensor): + Hard-copied right context frames, with shape (R, B, D), + where R = num_chunks * right_context_length + memory (torch.Tensor): + Memory elements, with shape (M, B, D), where M = num_chunks - 1. + It is an empty tensor without using memory. + attention_mask (torch.Tensor): + Pre-computed attention mask to simulate underlying chunk-wise + attention, with shape (Q, KV). + padding_mask (torch.Tensor): + Padding mask of key tensor, with shape (B, KV). + + Returns: + Output of right context and utterance, with shape (R + U, B, D). + """ + output_right_context_utterance, _, _ = self._forward_impl( + utterance, + right_context, + memory, + attention_mask, + padding_mask=padding_mask, + ) + return output_right_context_utterance + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + left_context_key: torch.Tensor, + left_context_val: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass for inference. + + B: batch size; + D: embedding dimension; + R: length of right context; + U: length of utterance, i.e., current chunk; + L: length of cached left context; + M: length of cached memory vectors. + + It concatenates the right context and utterance (i.e., current chunk) + of current chunk, to compute the query tensor: + query = [right_context, utterance], + with length Q = R + U. 
+ It concatenates the memory vectors, right context, left context, and + current chunk, to compute the key and value tensors: + key & value = [memory, right_context, left_context, utterance], + with length KV = M + R + L + U. + + The chunk-wise attention is: + chunk, right context (in query) -> + left context, chunk, right context, memory vectors (in key). + + Args: + utterance (torch.Tensor): + Current chunk frames, with shape (U, B, D), where U = chunk_length. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D), + where R = right_context_length. + memory (torch.Tensor): + Memory vectors, with shape (M, B, D), or empty tensor. + left_context_key (torch,Tensor): + Cached attention key of left context from preceding computation, + with shape (L, B, D). + left_context_val (torch.Tensor): + Cached attention value of left context from preceding computation, + with shape (L, B, D). + padding_mask (torch.Tensor): + Padding mask of key tensor, with shape (B, KV). + + Returns: + A tuple containing 4 tensors: + - output of right context and utterance, with shape (R + U, B, D). + - attention key of left context and utterance, which would be cached + for next computation, with shape (L + U, B, D). + - attention value of left context and utterance, which would be + cached for next computation, with shape (L + U, B, D). + """ + if self.is_pnnx is False: + U = utterance.size(0) + R = right_context.size(0) + L = left_context_key.size(0) + M = memory.size(0) + else: + U = self.chunk_length + R = self.right_context_length + L = self.left_context_length + M = self.memory_size + + # query = [right context, utterance] + Q = R + U + # key, value = [memory, right context, left context, utterance] + KV = M + R + L + U + attention_mask = torch.zeros(Q, KV).to( + dtype=torch.bool, device=utterance.device + ) + + output_right_context_utterance, key, value = self._forward_impl( + utterance, + right_context, + memory, + attention_mask, + padding_mask=padding_mask, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) + return ( + output_right_context_utterance, + key[M + R :], + value[M + R :], + ) + + +class EmformerEncoderLayer(nn.Module): + """Emformer layer that constitutes Emformer. + + Args: + d_model (int): + Input dimension. + nhead (int): + Number of attention heads. + dim_feedforward (int): + Hidden layer dimension of feedforward network. + chunk_length (int): + Length of each input segment. + dropout (float, optional): + Dropout probability. (Default: 0.0) + layer_dropout (float, optional): + Layer dropout probability. (Default: 0.0) + cnn_module_kernel (int): + Kernel size of convolution module. + left_context_length (int, optional): + Length of left context. (Default: 0) + right_context_length (int, optional): + Length of right context. (Default: 0) + memory_size (int, optional): + Number of memory elements to use. (Default: 0) + tanh_on_mem (bool, optional): + If ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. 
(Default: -1e8) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int, + chunk_length: int, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + left_context_length: int = 0, + right_context_length: int = 0, + memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + is_pnnx: bool = True, + ): + super().__init__() + + self.attention = EmformerAttention( + embed_dim=d_model, + nhead=nhead, + left_context_length=left_context_length, + chunk_length=chunk_length, + memory_size=memory_size, + right_context_length=right_context_length, + dropout=dropout, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + is_pnnx=is_pnnx, + ) + self.summary_op = nn.AvgPool1d( + kernel_size=chunk_length, stride=chunk_length, ceil_mode=True + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule( + chunk_length, + right_context_length, + d_model, + cnn_module_kernel, + is_pnnx=is_pnnx, + ) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean + # (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + self.layer_dropout = layer_dropout + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.chunk_length = chunk_length + self.memory_size = memory_size + self.d_model = d_model + self.use_memory = memory_size > 0 + + def _update_attn_cache( + self, + next_key: torch.Tensor, + next_val: torch.Tensor, + memory: torch.Tensor, + attn_cache: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Update cached attention state: + 1) output memory of current chunk in the lower layer; + 2) attention key and value in current chunk's computation, which would + be reused in next chunk's computation. 
+ """ + # attn_cache[0].shape (self.memory_size, 1, 512) + # memory.shape (1, 1, 512) + # attn_cache[1].shape (self.left_context_length, 1, 512) + # attn_cache[2].shape (self.left_context_length, 1, 512) + # next_key.shape (self.left_context_length + self.right_context_utterance, 1, 512) + # next_value.shape (self.left_context_length + self.right_context_utterance, 1, 512) + new_memory = torch.cat([attn_cache[0], memory]) + # TODO(fangjun): Remove torch.cat + # new_key = torch.cat([attn_cache[1], next_key]) + # new_val = torch.cat([attn_cache[2], next_val]) + attn_cache[0] = new_memory[1:] + attn_cache[1] = next_key[-self.left_context_length :] + attn_cache[2] = next_val[-self.left_context_length :] + return attn_cache + + def _apply_conv_module_forward( + self, + right_context_utterance: torch.Tensor, + R: int, + ) -> torch.Tensor: + """Apply convolution module in training and validation mode.""" + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + utterance, right_context = self.conv_module(utterance, right_context) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance + + def _apply_conv_module_infer( + self, + right_context_utterance: torch.Tensor, + R: int, + conv_cache: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply convolution module on utterance in inference mode.""" + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + utterance, right_context, conv_cache = self.conv_module.infer( + utterance, right_context, conv_cache + ) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance, conv_cache + + def _apply_attention_module_forward( + self, + right_context_utterance: torch.Tensor, + R: int, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply attention module in training and validation mode.""" + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + + if self.use_memory: + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] + else: + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + output_right_context_utterance = self.attention( + utterance=utterance, + right_context=right_context, + memory=memory, + attention_mask=attention_mask, + padding_mask=padding_mask, + ) + + return output_right_context_utterance + + def _apply_attention_module_infer( + self, + right_context_utterance: torch.Tensor, + R: int, + attn_cache: List[torch.Tensor], + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Apply attention module in inference mode. + 1) Unpack cached states including: + - memory from previous chunks; + - attention key and value of left context from preceding + chunk's compuation; + 2) Apply attention computation; + 3) Update cached attention states including: + - memory of current chunk; + - attention key and value in current chunk's computation, which would + be resued in next chunk's computation. 
+ """ + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + + pre_memory = attn_cache[0] + left_context_key = attn_cache[1] + left_context_val = attn_cache[2] + + if self.use_memory: + memory = torch.mean(utterance, dim=0, keepdim=True) + + # memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + # :1, :, : + # ] + else: + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val) = self.attention.infer( + utterance=utterance, + right_context=right_context, + memory=pre_memory, + left_context_key=left_context_key, + left_context_val=left_context_val, + padding_mask=padding_mask, + ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + return output_right_context_utterance, attn_cache + + def forward( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Forward pass for training and validation mode. + + B: batch size; + D: embedding dimension; + R: length of hard-copied right contexts; + U: length of full utterance; + M: length of memory vectors. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + attention_mask (torch.Tensor): + Attention mask for underlying attention module, + with shape (Q, KV), where Q = R + U, KV = M + R + U. + padding_mask (torch.Tensor): + Padding mask of ker tensor, with shape (B, KV). + + Returns: + A tuple containing 2 tensors: + - output utterance, with shape (U, B, D). + - output right context, with shape (R, B, D). + """ + R = right_context.size(0) + src = torch.cat([right_context, utterance]) + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # emformer attention module + src_att = self._apply_attention_module_forward( + src, R, attention_mask, padding_mask=padding_mask + ) + src = src + self.dropout(src_att) + + # convolution module + src_conv = self._apply_conv_module_forward(src, R) + src = src + self.dropout(src_conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + output_utterance = src[R:] + output_right_context = src[:R] + return output_utterance, output_right_context + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + cache: List[torch.Tensor], + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """Forward pass for inference. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + M: length of memory. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + attn_cache (List[torch.Tensor]): + Cached attention tensors generated in preceding computation, + including memory, key and value of left context. 
+ conv_cache (torch.Tensor, optional):
+ Cache tensor of left context for causal convolution.
+ padding_mask (torch.Tensor):
+ Padding mask of key tensor.
+
+ Returns:
+ (Tensor, Tensor, List[torch.Tensor]):
+ - output utterance, with shape (U, B, D);
+ - output right_context, with shape (R, B, D);
+ - output state list: the updated attention caches followed by
+ the convolution cache.
+ """
+ R = self.right_context_length
+ src = torch.cat([right_context, utterance])
+ attn_cache = cache[:3]
+ conv_cache = cache[3]
+
+ # macaron style feed forward module
+ src = src + self.dropout(self.feed_forward_macaron(src))
+
+ # emformer attention module
+ src_att, attn_cache = self._apply_attention_module_infer(
+ src, R, attn_cache, padding_mask=padding_mask
+ )
+ src = src + self.dropout(src_att)
+
+ # convolution module
+ src_conv, conv_cache = self._apply_conv_module_infer(src, R, conv_cache)
+ src = src + self.dropout(src_conv)
+
+ # feed forward module
+ src = src + self.dropout(self.feed_forward(src))
+
+ src = self.norm_final(self.balancer(src))
+
+ output_utterance = src[R:]
+ output_right_context = src[:R]
+ return (output_utterance, output_right_context, attn_cache + [conv_cache])
+
+
+def _gen_attention_mask_block(
+ col_widths: List[int],
+ col_mask: List[bool],
+ num_rows: int,
+ device: torch.device,
+) -> torch.Tensor:
+ assert len(col_widths) == len(
+ col_mask
+ ), "Length of col_widths must match that of col_mask"
+
+ mask_block = [
+ torch.ones(num_rows, col_width, device=device)
+ if is_ones_col
+ else torch.zeros(num_rows, col_width, device=device)
+ for col_width, is_ones_col in zip(col_widths, col_mask)
+ ]
+ return torch.cat(mask_block, dim=1)
+
+
+class EmformerEncoder(nn.Module):
+ """Implements the Emformer architecture introduced in
+ *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency
+ Streaming Speech Recognition*
+ [:footcite:`shi2021emformer`].
+
+ In this model, the memory bank computation is simplified, using the averaged
+ value of each chunk as its memory vector.
+
+ Args:
+ d_model (int):
+ Input dimension.
+ nhead (int):
+ Number of attention heads in each emformer layer.
+ dim_feedforward (int):
+ Hidden layer dimension of each emformer layer's feedforward network.
+ num_encoder_layers (int):
+ Number of emformer layers to instantiate.
+ chunk_length (int):
+ Length of each input segment.
+ dropout (float, optional):
+ Dropout probability. (default: 0.1)
+ layer_dropout (float, optional):
+ Layer dropout probability. (default: 0.075)
+ cnn_module_kernel (int):
+ Kernel size of convolution module.
+ left_context_length (int, optional):
+ Length of left context. (default: 0)
+ right_context_length (int, optional):
+ Length of right context. (default: 0)
+ memory_size (int, optional):
+ Number of memory elements to use. (default: 0)
+ tanh_on_mem (bool, optional):
+ If ``true``, applies tanh to memory elements. (default: ``false``)
+ negative_inf (float, optional):
+ Value to use for negative infinity in attention weights.
(default: -1e8) + """ + + def __init__( + self, + chunk_length: int, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + left_context_length: int = 0, + right_context_length: int = 0, + memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + is_pnnx: bool = True, + ): + super().__init__() + + assert ( + chunk_length - 1 + ) & chunk_length == 0, "chunk_length should be a power of 2." + self.shift = int(math.log(chunk_length, 2)) + + self.use_memory = memory_size > 0 + + self.emformer_layers = nn.ModuleList( + [ + EmformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + chunk_length=chunk_length, + dropout=dropout, + layer_dropout=layer_dropout, + cnn_module_kernel=cnn_module_kernel, + left_context_length=left_context_length, + right_context_length=right_context_length, + memory_size=memory_size, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + is_pnnx=is_pnnx, + ) + for layer_idx in range(num_encoder_layers) + ] + ) + + self.num_encoder_layers = num_encoder_layers + self.d_model = d_model + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.chunk_length = chunk_length + self.memory_size = memory_size + self.cnn_module_kernel = cnn_module_kernel + + def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: + """Hard copy each chunk's right context and concat them.""" + T = x.shape[0] + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + # first (num_chunks - 1) right context block + intervals = torch.arange( + 0, self.chunk_length * (num_chunks - 1), self.chunk_length + ) + first = torch.arange( + self.chunk_length, self.chunk_length + self.right_context_length + ) + indexes = intervals.unsqueeze(1) + first.unsqueeze(0) + # cat last right context block + indexes = torch.cat( + [ + indexes, + torch.arange(T - self.right_context_length, T).unsqueeze(0), + ] + ) + right_context_blocks = x[indexes.reshape(-1)] + return right_context_blocks + + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + """Calculate column widths (key, value) in attention mask for the + chunk_idx chunk.""" + num_chunks = math.ceil(U / self.chunk_length) + rc = self.right_context_length + lc = self.left_context_length + rc_start = chunk_idx * rc + rc_end = rc_start + rc + chunk_start = max(chunk_idx * self.chunk_length - lc, 0) + chunk_end = min((chunk_idx + 1) * self.chunk_length, U) + R = rc * num_chunks + + if self.use_memory: + m_start = max(chunk_idx - self.memory_size, 0) + M = num_chunks - 1 + col_widths = [ + m_start, # before memory + chunk_idx - m_start, # memory + M - chunk_idx, # after memory + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + else: + col_widths = [ + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + + return col_widths + + def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor: + """Generate attention mask to simulate underlying chunk-wise attention + computation, where chunk-wise connections are filled with `False`, + and other unnecessary connections beyond chunk are filled 
with `True`. + + R: length of hard-copied right contexts; + U: length of full utterance; + M: length of memory vectors; + Q: length of attention query; + KV: length of attention key and value. + + The shape of attention mask is (Q, KV). + If self.use_memory is `True`: + query = [right_context, utterance]; + key, value = [memory, right_context, utterance]; + Q = R + U, KV = M + R + U. + Otherwise: + query = [right_context, utterance] + key, value = [right_context, utterance] + Q = R + U, KV = R + U. + + Suppose: + c_i: chunk at index i; + r_i: right context that c_i can use; + l_i: left context that c_i can use; + m_i: past memory vectors from previous layer that c_i can use; + The target chunk-wise attention is: + c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key). + """ + U = utterance.size(0) + num_chunks = math.ceil(U / self.chunk_length) + + right_context_mask = [] + utterance_mask = [] + + if self.use_memory: + num_cols = 9 + # right context and utterance both attend to memory, right context, + # utterance + right_context_utterance_cols_mask = [ + idx in [1, 4, 7] for idx in range(num_cols) + ] + else: + num_cols = 6 + # right context and utterance both attend to right context and + # utterance + right_context_utterance_cols_mask = [ + idx in [1, 4] for idx in range(num_cols) + ] + masks_to_concat = [right_context_mask, utterance_mask] + + for chunk_idx in range(num_chunks): + col_widths = self._gen_attention_mask_col_widths(chunk_idx, U) + + right_context_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + self.right_context_length, + utterance.device, + ) + right_context_mask.append(right_context_mask_block) + + utterance_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + min( + self.chunk_length, + U - chunk_idx * self.chunk_length, + ), + utterance.device, + ) + utterance_mask.append(utterance_mask_block) + + attention_mask = ( + 1 - torch.cat([torch.cat(mask) for mask in masks_to_concat]) + ).to(torch.bool) + return attention_mask + + def _forward( + self, x: torch.Tensor, lengths: torch.Tensor, warmup: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and validation mode. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, which contains the + right_context at the end. + + Returns: + A tuple of 2 tensors: + - output utterance frames, with shape (U, B, D). + - output_lengths, with shape (B,), without containing the + right_context at the end. 
+ """ + U = x.size(0) - self.right_context_length + + right_context = self._gen_right_context(x) + utterance = x[:U] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) + attention_mask = self._gen_attention_mask(utterance) + + M = ( + right_context.size(0) // self.right_context_length - 1 + if self.use_memory + else 0 + ) + padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths) + + output = utterance + for layer in self.emformer_layers: + output, right_context = layer( + output, + right_context, + attention_mask, + padding_mask=padding_mask, + warmup=warmup, + ) + + return output, output_lengths + + @torch.jit.export + def infer( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward pass for streaming inference. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, which contains the + right_context at the end. + states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa + Cached states containing: + - attn_caches: attention states from preceding chunk's computation, + where each element corresponds to each emformer layer + - conv_caches: left context for causal convolution, where each + element corresponds to each layer. + + Returns: + (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]): + - output utterance frames, with shape (U, B, D). + - output lengths, with shape (B,), without containing the + right_context at the end. + - updated states from current chunk's computation. + """ + # lengths = chunk_length + right_context_length + utterance = x[: self.chunk_length] + right_context = x[self.chunk_length :] + # right_context_utterance = torch.cat([right_context, utterance]) + + output = utterance + output_states: List[torch.Tensor] = [] + for layer_idx, layer in enumerate(self.emformer_layers): + start = layer_idx * 4 + end = start + 4 + cache = states[start:end] + + (output, right_context, output_cache,) = layer.infer( + output, + right_context, + padding_mask=None, + cache=cache, + ) + output_states.extend(output_cache) + + return output, output_states + + @torch.jit.export + def init_states( + self, device: torch.device = torch.device("cpu") + ) -> List[torch.Tensor]: + """Create initial states.""" + # + states = [] + # layer0: attn cache, conv cache, 3 tensors + 1 tensor + # layer1: attn cache, conv cache, 3 tensors + 1 tensor + # layer2: attn cache, conv cache, 3 tensors + 1 tensor + # ... 
+ # last layer: attn cache, conv cache, 3 tensors + 1 tensor + for i in range(self.num_encoder_layers): + states.append(torch.zeros(self.memory_size, 1, self.d_model, device=device)) + states.append( + torch.zeros(self.left_context_length, 1, self.d_model, device=device) + ) + states.append( + torch.zeros(self.left_context_length, 1, self.d_model, device=device) + ) + + states.append( + torch.zeros(1, self.d_model, self.cnn_module_kernel - 1, device=device) + ) + return states + + +class Emformer(EncoderInterface): + def __init__( + self, + num_features: int, + chunk_length: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 3, + left_context_length: int = 0, + right_context_length: int = 0, + memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + is_pnnx: bool = True, + ): + super().__init__() + + self.subsampling_factor = subsampling_factor + self.right_context_length = right_context_length + self.chunk_length = chunk_length + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + if chunk_length % subsampling_factor != 0: + raise NotImplementedError( + "chunk_length must be a mutiple of subsampling_factor." + ) + if left_context_length != 0 and left_context_length % subsampling_factor != 0: + raise NotImplementedError( + "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa + ) + if right_context_length != 0 and right_context_length % subsampling_factor != 0: + raise NotImplementedError( + "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa + ) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx) + self.is_pnnx = is_pnnx + + self.num_encoder_layers = num_encoder_layers + self.memory_size = memory_size + self.d_model = d_model + self.cnn_module_kernel = cnn_module_kernel + self.left_context_length = left_context_length // subsampling_factor + self.right_context_length = right_context_length + self.subsampling_factor = subsampling_factor + + assert subsampling_factor == 4, subsampling_factor + pad_length = right_context_length + 2 * 4 + 3 + + # before subsampling + self.T = self.chunk_length + pad_length + + self.encoder = EmformerEncoder( + chunk_length=chunk_length // subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + dropout=dropout, + layer_dropout=layer_dropout, + cnn_module_kernel=cnn_module_kernel, + left_context_length=left_context_length // subsampling_factor, + right_context_length=right_context_length // subsampling_factor, + memory_size=memory_size, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + is_pnnx=is_pnnx, + ) + + def _forward( + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and non-streaming inference. + + B: batch size; + D: feature dimension; + T: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, T, D). 
+ x_lens (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, containing the + right_context at the end. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + (Tensor, Tensor): + - output embedding, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. + - output lengths, with shape (B,), without containing the + right_context at the end. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + x_lens = (((x_lens - 1) >> 1) - 1) >> 1 + assert x.size(0) == x_lens.max().item() + + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + + output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + + return output, output_lengths + + def forward( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward pass for streaming inference. + + B: batch size; + D: feature dimension; + T: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, T, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, containing the + right_context at the end. + states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa + Cached states containing: + - past_lens: number of past frames for each sample in batch + - attn_caches: attention states from preceding chunk's computation, + where each element corresponds to each emformer layer + - conv_caches: left context for causal convolution, where each + element corresponds to each layer. + Returns: + (Tensor, Tensor): + - output embedding, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. + - output lengths, with shape (B,), without containing the + right_context at the end. + - updated states from current chunk's computation. + """ + x = self.encoder_embed(x) + # drop the first and last frames + x = x[:, 1:-1, :] + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + + output, output_states = self.encoder.infer(x, states) + + output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + + return output, output_states + + @torch.jit.export + def init_states( + self, device: torch.device = torch.device("cpu") + ) -> List[torch.Tensor]: + """Create initial states.""" + return self.encoder.init_states(device) + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + is_pnnx: bool = True, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. 
The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 7 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + + if torch.jit.is_tracing() and self.is_pnnx is True: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-1)//2-1)//2, ((idim-1)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py new file mode 100755 index 000000000..8fbb02f14 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +Usage: +./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir ./conv_emformer_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 \ + +cd ./conv_emformer_transducer_stateless2/exp +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. 
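+
+Note: running the three pnnx commands above is expected to produce ncnn
+model files (e.g., *.ncnn.param and *.ncnn.bin) next to the corresponding
+*-pnnx.pt files; the exact output names depend on your pnnx version, so
+please double check them before using them with ./streaming-ncnn-decode.py
+or sherpa-ncnn.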
+""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + chunk_length = encoder_model.chunk_length # before subsampling + right_context_length = encoder_model.right_context_length # before subsampling + pad_length = right_context_length + 2 * 4 + 3 + s = f"chunk_length: {chunk_length}, " + s += f"right_context_length: {right_context_length}\n" + logging.info(s) + + T = chunk_length + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.init_states() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. 
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg 
+ start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py new file mode 100755 index 000000000..ad0b45bd9 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -0,0 +1,644 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./conv_emformer_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to +use the exported ONNX models. 
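+
+(Optional) To double check the exported encoder, you can print the meta data
+that this script attaches to it with something like:
+
+  python3 -c "import onnx; m = onnx.load('encoder-epoch-99-avg-1.onnx'); print(m.metadata_props)"
+
+The file name above assumes the `--epoch 99 --avg 1` values used in step 2;
+adjust it if you export with different options.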
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model +from emformer import Emformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Emformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Emformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Emformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Emformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - a list of states (each layers has 4 states) + + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - a list of states (each layers has 4 states) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
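+
+ Note:
+ The per-layer states are flattened into inputs named layer{i}_0 ...
+ layer{i}_3 and outputs named new_layer{i}_0 ... new_layer{i}_3; see
+ build_inputs_outputs() in the function body below.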
+ """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + memory_size = encoder_model.encoder.memory_size + cnn_module_kernel = encoder_model.encoder.cnn_module_kernel + chunk_length = encoder_model.encoder.chunk_length + right_context_length = encoder_model.encoder.right_context_length + encoder_dim = encoder_model.encoder.d_model + left_context_length = encoder_model.encoder.left_context_length + + T = encoder_model.encoder.T + + logging.info(f"num_encoder_layers={num_encoder_layers}") + logging.info(f"memory_size={memory_size}") + logging.info(f"cnn_module_kernel={cnn_module_kernel}") + logging.info(f"chunk_length={chunk_length}") + logging.info(f"right_context_length={right_context_length}") + logging.info(f"encoder_dim={encoder_dim}") + logging.info(f"left_context_length={left_context_length} (after subsampling)") + logging.info(f"T={T}") + + meta_data = { + "model_type": "conv-emformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(chunk_length), # 32 + "T": str(T), # 32 + "num_encoder_layers": str(num_encoder_layers), + "memory_size": str(memory_size), + "cnn_module_kernel": str(cnn_module_kernel), + "right_context_length": str(right_context_length), + "left_context_length": str(left_context_length), + "encoder_dim": str(encoder_dim), + } + logging.info(f"meta_data: {meta_data}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.init_states() + + # Each layer has 4 states + assert len(states) == num_encoder_layers * 4, (len(states), num_encoder_layers) + # layer 0: + # state0: (memory_size, 1, encoder_dim) + # state1: (left_context_length, 1, encoder_dim) + # state2: (left_context_length, 1, encoder_dim) + # state3: (1, encoder_dim, cnn_module_kernel-1) + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(s, name): + assert len(s) == 4, len(s) + logging.info(f"{name}_0.shape: {s[0].shape}") + input_names.append(f"{name}_0") + inputs[f"{name}_0"] = {1: "N"} + output_names.append(f"new_{name}_0") + + logging.info(f"{name}_1.shape: {s[1].shape}") + input_names.append(f"{name}_1") + inputs[f"{name}_1"] = {1: "N"} + output_names.append(f"new_{name}_1") + + logging.info(f"{name}_2.shape: {s[2].shape}") + input_names.append(f"{name}_2") + inputs[f"{name}_2"] = {1: "N"} + output_names.append(f"new_{name}_2") + + logging.info(f"{name}_3.shape: {s[3].shape}") + input_names.append(f"{name}_3") + inputs[f"{name}_3"] = {0: "N"} + output_names.append(f"new_{name}_3") + + for i in range(num_encoder_layers): + base_name = f"layer{i}" + s = states[i * 4 : (i + 1) * 4] + build_inputs_outputs(s, base_name) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. 
+ opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.is_pnnx = False + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( 
+ f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index ab15e0241..b53426c75 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -64,6 +64,7 @@ from pathlib import Path import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, 
get_params, get_transducer_model from icefall.checkpoint import ( @@ -136,8 +137,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -181,9 +181,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +210,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -259,6 +259,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not @@ -279,9 +280,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py new file mode 100755 index 000000000..1fe358c79 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. 
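+
+It simulates streaming recognition: the audio is fed chunk by chunk into a
+streaming fbank extractor, and greedy search is run incrementally on the
+encoder output of each chunk.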
+You can use the following command to get the exported models: + +./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir ./conv_emformer_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +./conv_emformer_transducer_stateless2/jit_pretrained.py \ + --encoder-model-filename ./conv_emformer_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt \ + --decoder-model-filename ./conv_emformer_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt \ + --joiner-model-filename ./conv_emformer_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from torch.nn.utils.rnn import pad_sequence +from typing import Optional, List + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = encoder.chunk_length + right_context_length = encoder.right_context_length + + # Assume the subsampling factor is 4 + pad_length = right_context_length + 2 * 4 + 3 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + logging.info(f"right_context_length: {right_context_length}") + + states = encoder.init_states(device) + logging.info(f"num layers: {len(states)//4}") + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + 
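+        # Feed the raw samples of this chunk to the streaming fbank
+        # extractor.  Whenever at least T = chunk_length + pad_length feature
+        # frames are available, the inner loop below runs the encoder on a
+        # window of T frames and then advances num_processed_frames by
+        # chunk_length, so consecutive windows overlap by pad_length frames
+        # (the right context plus a margin for the subsampling).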
online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += chunk_length + frames = torch.cat(frames, dim=0).unsqueeze(0) + # TODO(fangjun): remove x_lens + x_lens = torch.tensor([T]) + encoder_out, _, states = encoder(frames, x_lens, states) + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py new file mode 100755 index 000000000..db92ac696 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./conv_emformer_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. 
Run this file with the exported ONNX models + +./conv_emformer_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + model_type = encoder_meta["model_type"] + assert model_type == "conv-emformer", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = int(encoder_meta["num_encoder_layers"]) + memory_size = int(encoder_meta["memory_size"]) + cnn_module_kernel = int(encoder_meta["cnn_module_kernel"]) + right_context_length = int(encoder_meta["right_context_length"]) + left_context_length = int(encoder_meta["left_context_length"]) + encoder_dim = int(encoder_meta["encoder_dim"]) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"memory_size: {memory_size}") + logging.info(f"cnn_module_kernel: {cnn_module_kernel}") + logging.info(f"left_context_length: {left_context_length} (after subsampling)") + logging.info(f"right_context_length: {right_context_length}") + logging.info(f"encoder_dim: {encoder_dim}") + + N = batch_size + + states = [] + for i in range(num_encoder_layers): + s0 = torch.zeros(memory_size, N, encoder_dim) + s1 = torch.zeros(left_context_length, N, encoder_dim) + s2 = 
torch.zeros(left_context_length, N, encoder_dim) + s3 = torch.zeros(N, encoder_dim, cnn_module_kernel - 1) + states.extend([s0, s1, s2, s3]) + + self.states = states + + self.segment = T + self.offset = decode_chunk_len + self.num_encoder_layers = num_encoder_layers + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_inputs_outputs(states: List[torch.Tensor], name: str): + for i in range(4): + if isinstance(states[i], torch.Tensor): + encoder_input[f"{name}_{i}"] = states[i].numpy() + else: + encoder_input[f"{name}_{i}"] = states[i] + + encoder_output.append(f"new_{name}_{i}") + + for i in range(self.num_encoder_layers): + base_name = f"layer{i}" + s = self.states[i * 4 : (i + 1) * 4] + build_inputs_outputs(s, base_name) + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
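+      Only the first channel of each file is kept, and each returned tensor
+      is made contiguous.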
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. + """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + 
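+            # Run one segment through the ONNX encoder.  run_encoder() passes
+            # the cached per-layer states along with the features and stores
+            # the returned new states for the next segment, so only the
+            # features are given here.  greedy_search() then extends `hyp`,
+            # reusing `decoder_out` across segments.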
encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py new file mode 100755 index 000000000..74da9e6c8 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
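+
+# Note on the hard-coded streaming configuration in `Model` below:
+# chunk_length = 32 and right_context_length = 8 (both counted before
+# subsampling), so pad_length = 8 + 2 * 4 + 3 = 19 and each encoder call
+# consumes T = 32 + 19 = 51 feature frames.  For each of the 12 Emformer
+# layers, get_init_states() creates four state tensors that run_encoder()
+# carries from chunk to chunk:
+#   (memory_size, d_model)           = (32, 512)
+#   (left_context_length, d_model)   = (8, 512)   # 32 // 4 after subsampling
+#   (left_context_length, d_model)   = (8, 512)
+#   (d_model, cnn_module_kernel - 1) = (512, 30)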
+""" +Usage: + +./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/tokens.txt \ + --encoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \ + --encoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \ + --decoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \ + --decoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \ + --joiner-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \ + --joiner-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \ + ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/test_wavs/1089-134686-0001.wav + +You can find pretrained models at +https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04 +""" + +import argparse +import logging +from typing import List, Optional, Tuple + +import k2 +import ncnn +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + self.num_layers = 12 + self.memory_size = 32 + self.d_model = 512 + self.cnn_module_kernel = 31 + + self.left_context_length = 32 // 4 # after subsampling + self.chunk_length = 32 # before subsampling + right_context_length = 8 # before subsampling + pad_length = right_context_length + 2 * 4 + 3 + self.T = self.chunk_length + pad_length + print("T", self.T, self.chunk_length) + + def get_init_states(self) -> List[torch.Tensor]: + states = [] + + for i in range(self.num_layers): + s0 = torch.zeros(self.memory_size, self.d_model) + s1 = torch.zeros(self.left_context_length, self.d_model) + s2 = torch.zeros(self.left_context_length, self.d_model) + s3 = torch.zeros(self.d_model, self.cnn_module_kernel - 1) + states.extend([s0, s1, s2, s3]) + + return states + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = 
args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + decoder_net.opt.num_threads = 4 + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.opt.num_threads = 4 + + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + def run_encoder( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: + A tensor of shape (T, C) + states: + A list of tensors. len(states) == self.num_layers * 4 + Returns: + Return a tuple containing: + - encoder_out, a tensor of shape (T, encoder_dim). + - next_states, a list of tensors containing the next states + """ + with self.encoder_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + + # layer0 in2-in5 + # layer1 in6-in9 + for i in range(self.num_layers): + offset = 1 + i * 4 + name = f"in{offset}" + # (32, 1, 512) -> (32, 512) + ex.input(name, ncnn.Mat(states[i * 4 + 0].numpy()).clone()) + + name = f"in{offset+1}" + # (8, 1, 512) -> (8, 512) + ex.input(name, ncnn.Mat(states[i * 4 + 1].numpy()).clone()) + + name = f"in{offset+2}" + # (8, 1, 512) -> (8, 512) + ex.input(name, ncnn.Mat(states[i * 4 + 2].numpy()).clone()) + + name = f"in{offset+3}" + # (1, 512, 2) -> (512, 2) + ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + + out_states: List[torch.Tensor] = [] + for i in range(4 * self.num_layers): + name = f"out{i+1}" + ret, ncnn_out_state = ex.extract(name) + assert ret == 0, ret + ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy()) + out_states.append(ncnn_out_state) + + return encoder_out, out_states + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. 
In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + else: + assert decoder_out.ndim == 1 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t] + + joiner_out = model.run_joiner(cur_encoder_out, decoder_out) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + states = model.get_init_states() + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = model.T + offset = model.chunk_length + + chunk = int(1 * sample_rate) # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, states = model.run_encoder(frames, states) + hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + context_size = 2 + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(sound_file) + logging.info(text) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 71150392d..f5d894a7b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -211,8 +211,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -371,9 +370,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -390,9 +387,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,14 +546,10 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = ( - 3 * params.subsampling_factor + params.right_context_length + 3 - ) + tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -605,9 +596,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -761,17 +750,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -781,10 +766,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -831,9 +813,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -867,9 +847,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -896,9 +876,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2bbc45d78..8462ae92a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -265,8 +263,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -289,8 +286,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -636,11 +632,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -668,23 +660,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -871,9 +856,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -981,16 +964,16 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py new file mode 100755 index 000000000..dd0a60736 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py @@ -0,0 +1,1130 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conv_emformer_transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conv_emformer_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 280 \ + --master-port 12321 \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +# For mix precision training: +./conv_emformer_transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir conv_emformer_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 300 \ + --master-port 12321 \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from emformer2 import Emformer +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--encoder-dim", + type=int, + default=512, + help="Attention dim for the Emformer", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads for the Emformer", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Feed-forward dimension for the Emformer", + ) + + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of encoder layers for the Emformer", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=int, + default=31, + help="Kernel size for the convolution module.", + ) + + parser.add_argument( + "--left-context-length", + type=int, + default=32, + help="""Number of frames before subsampling for left context + in the Emformer.""", + ) + + parser.add_argument( + "--chunk-length", + type=int, + default=32, + help="""Number of frames before subsampling for each chunk + in the Emformer.""", + ) + + parser.add_argument( + "--right-context-length", + type=int, + default=8, + help="""Number of frames before subsampling for right context + in the Emformer.""", + ) + + parser.add_argument( + "--memory-size", + type=int, + default=0, + help="Number of entries in the memory for the Emformer", + ) + + +def 
get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
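+
+        - is_pnnx: Set to True in this script so that the Emformer imported
+          from emformer2.py is constructed in its pnnx-exportable form (an
+          assumption based on the flag name; see get_encoder_model()).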
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for Emformer + "feature_dim": 80, + "subsampling_factor": 4, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "is_pnnx": True, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Emformer( + num_features=params.feature_dim, + chunk_length=params.chunk_length, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + cnn_module_kernel=params.cnn_module_kernel, + left_context_length=params.left_context_length, + right_context_length=params.right_context_length, + memory_size=params.memory_size, + is_pnnx=params.is_pnnx, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
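+
+    # Load the checkpoint into `model` (and, if given, `model_avg`,
+    # `optimizer` and `scheduler`), then copy bookkeeping entries such as
+    # `best_train_loss` and `batch_idx_train` back into `params` so that
+    # training resumes with the same statistics.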
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
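+    Returns:
+      Return a tuple of two elements: the total loss and a MetricsTracker
+      containing statistics such as the number of frames and the individual
+      simple/pruned loss values.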
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. 
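The inner loop of `train_one_epoch()` above follows the usual `torch.cuda.amp` mixed-precision pattern (with the Eden scheduler additionally stepped once per batch via `step_batch()`). A generic, device-agnostic sketch of that step order, not specific to this recipe:

```python
# Generic torch.cuda.amp training step: forward under autocast, backward on
# the scaled loss, then GradScaler.step()/update() before zeroing gradients.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"

model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

x = torch.randn(4, 10, device=device)
y = torch.randint(0, 2, (4,), device=device)

with torch.cuda.amp.autocast(enabled=use_fp16):
    loss = torch.nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales gradients; skips the step on inf/nan
scaler.update()                # adjust the loss scale for the next iteration
optimizer.zero_grad()
```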
+ The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
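To pick these duration thresholds, the comment here suggests inspecting the utterance duration distribution of the training data. As a rough sketch of doing that with lhotse percentiles (the manifest path below is a placeholder; the recipe's own tool for this is `./local/display_manifest_statistics.py`):

```python
# Rough sketch: inspect the duration distribution of a cut manifest to pick
# sensible min/max duration thresholds. The manifest path is a placeholder.
import numpy as np
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
durations = np.array([c.duration for c in cuts])

for p in (1, 50, 99, 99.9, 100):
    print(f"p{p:<5}: {np.percentile(durations, p):6.2f} s")
# Cuts falling outside, e.g., [1.0, 20.0] seconds are then dropped by the
# filter above before building the training dataloader.
```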
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=0.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh index 2a69d3921..6aaa0333b 100755 --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -35,7 +35,7 @@ stop_stage=4 # export CUDA_VISIBLE_DEVICES="0" # # Suppose GPU 2,3,4,5 are available. -export CUDA_VISIBLE_DEVICES="0,1,2,3" +# export CUDA_VISIBLE_DEVICES="0,1,2,3" exp_dir=./pruned_transducer_stateless6/exp mkdir -p $exp_dir @@ -43,13 +43,13 @@ mkdir -p $exp_dir # full_libri can be "True" or "False" # "True" -> use full librispeech dataset for distillation # "False" -> use train-clean-100 subset for distillation -full_libri=False +full_libri=True # use_extracted_codebook can be "True" or "False" # "True" -> stage 0 and stage 1 would be skipped, # and directly download the extracted codebook indexes for distillation # "False" -> start from scratch -use_extracted_codebook=False +use_extracted_codebook=True # teacher_model_id can be one of # "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use. @@ -145,8 +145,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960" exit 1 fi + # The codebook indexes to be downloaded are generated using the following setup: + embedding_layer=36 + num_codebooks=8 + mkdir -p $exp_dir/vq - codebook_dir=$exp_dir/vq/$teacher_model_id + codebook_dir=$exp_dir/vq/${teacher_model_id} mkdir -p codebook_dir codebook_download_dir=$exp_dir/download_codebook if [ -d $codebook_download_dir ]; then @@ -155,11 +159,18 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi log "Downloading extracted codebook indexes to $codebook_download_dir" # Make sure you have git-lfs installed (https://git-lfs.github.com) + # The codebook indexes are generated using lhotse 1.11.0, to avoid + # potential issues, we recommend you to use lhotse version >= 1.11.0 + lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))") + if [ "$lhotse_version" == "False" ]; then + log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch." 
+ fi git lfs install - git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir + git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir - mkdir -p data/vq_fbank - mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/ + vq_fbank=data/vq_fbank_layer${embedding_layer}_cb${num_codebooks}/ + mkdir -p $vq_fbank + mv $codebook_download_dir/*.jsonl.gz $vq_fbank mkdir -p $codebook_dir/splits4 mv $codebook_download_dir/*.h5 $codebook_dir/splits4/ log "Remove $codebook_download_dir" @@ -169,12 +180,21 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then ./pruned_transducer_stateless6/extract_codebook_index.py \ --full-libri $full_libri \ --exp-dir $exp_dir \ - --embedding-layer 36 \ + --embedding-layer $embedding_layer \ --num-utts 1000 \ - --num-codebooks 8 \ + --num-codebooks $num_codebooks \ --max-duration 100 \ --teacher-model-id $teacher_model_id \ --use-extracted-codebook $use_extracted_codebook + + if [ "$full_libri" == "True" ]; then + # Merge the 3 subsets and create a full one + rm ${vq_fbank}/librispeech_cuts_train-all-shuf.jsonl.gz + cat <(gunzip -c ${vq_fbank}/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c ${vq_fbank}/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c ${vq_fbank}/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > ${vq_fbank}/librispeech_cuts_train-all-shuf.jsonl.gz + fi fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then diff --git a/egs/librispeech/ASR/finetune.sh b/egs/librispeech/ASR/finetune.sh new file mode 100755 index 000000000..63d0966ed --- /dev/null +++ b/egs/librispeech/ASR/finetune.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +# This is an example script for fine-tuning. Here, we fine-tune a model trained +# on Librispeech on GigaSpeech. The model used for fine-tuning is +# pruned_transducer_stateless7 (zipformer). If you want to fine-tune model +# from another recipe, you can adapt ./pruned_transducer_stateless7/finetune.py +# for that recipe. If you have any problem, please open up an issue in https://github.com/k2-fsa/icefall/issues. + +# We assume that you have already prepared the GigaSpeech manfiest&features under ./data. +# If you haven't done that, please see https://github.com/k2-fsa/icefall/blob/master/egs/gigaspeech/ASR/prepare.sh. + +dl_dir=$PWD/download + +. 
shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download Pre-trained model" + + # clone from huggingface + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Start fine-tuning" + + # The following configuration of lr schedule should work well + # You may also tune the following parameters to adjust learning rate schedule + base_lr=0.005 + lr_epochs=100 + lr_batches=100000 + + # We recommend to start from an averaged model + finetune_ckpt=icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained.pt + export CUDA_VISIBLE_DEVICES="0,1" + + ./pruned_transducer_stateless7/finetune.py \ + --world-size 2 \ + --master-port 18180 \ + --num-epochs 20 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp_giga_finetune \ + --subset S \ + --use-fp16 1 \ + --base-lr $base_lr \ + --lr-epochs $lr_epochs \ + --lr-batches $lr_batches \ + --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \ + --do-finetune True \ + --finetune-ckpt $finetune_ckpt \ + --max-duration 500 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Decoding" + + epoch=15 + avg=10 + + for m in greedy_search modified_beam_search; do + python pruned_transducer_stateless7/decode_gigaspeech.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model True \ + --beam-size 4 \ + --exp-dir pruned_transducer_stateless7/exp_giga_finetune \ + --max-duration 400 \ + --decoding-method $m + done +fi diff --git a/egs/librispeech/ASR/generate-lm.sh b/egs/librispeech/ASR/generate-lm.sh index 6baccd381..dacd276d1 100755 --- a/egs/librispeech/ASR/generate-lm.sh +++ b/egs/librispeech/ASR/generate-lm.sh @@ -2,7 +2,7 @@ lang_dir=data/lang_bpe_500 -for ngram in 2 3 5; do +for ngram in 2 3 4 5; do if [ ! -f $lang_dir/${ngram}gram.arpa ]; then ./shared/make_kn_lm.py \ -ngram-order ${ngram} \ diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index cd1bcea67..cc34a72d8 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,9 +157,7 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info( - f"{part} has {len(alignments.keys())} cuts with alignments." - ) + logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -170,18 +168,14 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info( - f"Warning: {origin_id} does not has alignment." 
- ) + logging.info(f"Warning: {origin_id} does not have alignment.") ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 9a35750e0..d19d50ae6 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from Caution: We use a lexicon that contains disambiguation symbols - - G, the LM, built from data/lm/G_3_gram.fst.txt + - G, the LM, built from data/lm/G_n_gram.fst.txt The generated HLG is saved in $lang_dir/HLG.pt """ @@ -40,6 +40,13 @@ from icefall.lexicon import Lexicon def get_args(): parser = argparse.ArgumentParser() + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) parser.add_argument( "--lang-dir", type=str, @@ -50,11 +57,13 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str) -> k2.Fsa: +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: """ Args: lang_dir: The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. Return: An FSA representing HLG. @@ -65,15 +74,15 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: H = k2.ctc_topo(max_token_id) L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - if Path("data/lm/G_3_gram.pt").is_file(): - logging.info("Loading pre-compiled G_3_gram") - d = torch.load("data/lm/G_3_gram.pt") + if Path(f"data/lm/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"data/lm/{lm}.pt") G = k2.Fsa.from_dict(d) else: - logging.info("Loading G_3_gram.fst.txt") - with open("data/lm/G_3_gram.fst.txt") as f: + logging.info(f"Loading {lm}.fst.txt") + with open(f"data/lm/{lm}.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + torch.save(G.as_dict(), f"data/lm/{lm}.pt") first_token_disambig_id = lexicon.token_table["#0"] first_word_disambig_id = lexicon.word_table["#0"] @@ -100,10 +109,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels assert isinstance(LG.aux_labels, k2.RaggedTensor) LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 @@ -144,15 +154,13 @@ def main(): logging.info(f"Processing {lang_dir}") - HLG = compile_HLG(lang_dir) + HLG = compile_HLG(lang_dir, args.lm) logging.info(f"Saving HLG.pt to {lang_dir}") torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git 
a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py new file mode 100755 index 000000000..15fc47ef1 --- /dev/null +++ b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates HLG from + + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.fst + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_n_gram.fst.txt + +The generated HLG is saved in $lang_dir/HLG_fst.pt + +So when to use this script instead of ./local/compile_hlg.py ? +If you have a very large G, ./local/compile_hlg.py may throw OOM for +determinization. In that case, you can use this script to compile HLG. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import kaldifst +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> kaldifst.StdVectorFst: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + lm: + The language stem base name. + + Return: + An FST representing HLG. 
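The function body that follows composes L and G with kaldifst, determinizes and minimizes the result, and only then converts to k2 to compose with H. A toy run of the same kaldifst calls on two tiny hand-written FSTs (labels are arbitrary integers; this only sketches the call order, not the recipe's graphs):

```python
# Toy illustration of the kaldifst call sequence used below:
# arc-sort, compose, determinize_star, minimize_encoded.
import kaldifst

# L: one arc mapping input symbol 1 to output symbol 10 (OpenFst text format:
# "src dst ilabel olabel weight", final states listed as "state weight").
L = kaldifst.compile(s="0 1 1 10 0.0\n1 0.0\n", acceptor=False)
# G: accepts symbol 10 (written as a transducer with ilabel == olabel).
G = kaldifst.compile(s="0 1 10 10 0.0\n1 0.0\n", acceptor=False)

kaldifst.arcsort(L, sort_type="olabel")  # sort L on output labels
kaldifst.arcsort(G, sort_type="ilabel")  # sort G on input labels

LG = kaldifst.compose(L, G, connect=True)
kaldifst.determinize_star(LG)   # in-place determinization
kaldifst.minimize_encoded(LG)   # in-place minimization
print("LG: #states", LG.num_states)
```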
+ """ + + L = kaldifst.StdVectorFst.read(f"{lang_dir}/L_disambig.fst") + logging.info("Arc sort L") + kaldifst.arcsort(L, sort_type="olabel") + logging.info(f"L: #states {L.num_states}") + + G_filename_txt = f"data/lm/{lm}.fst.txt" + G_filename_binary = f"data/lm/{lm}.fst" + if Path(G_filename_binary).is_file(): + logging.info(f"Loading {G_filename_binary}") + G = kaldifst.StdVectorFst.read(G_filename_binary) + else: + logging.info(f"Loading {G_filename_txt}") + with open(G_filename_txt) as f: + G = kaldifst.compile(s=f.read(), acceptor=False) + logging.info(f"Saving G to {G_filename_binary}") + G.write(G_filename_binary) + + logging.info("Arc sort G") + kaldifst.arcsort(G, sort_type="ilabel") + + logging.info(f"G: #states {G.num_states}") + + logging.info("Compose L and G and connect LG") + LG = kaldifst.compose(L, G, connect=True) + logging.info(f"LG: #states {LG.num_states}") + + logging.info("Determinizestar LG") + kaldifst.determinize_star(LG) + logging.info(f"LG after determinize_star: #states {LG.num_states}") + + logging.info("Minimize encoded LG") + kaldifst.minimize_encoded(LG) + logging.info(f"LG after minimize_encoded: #states {LG.num_states}") + + logging.info("Converting LG to k2 format") + LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False) + logging.info(f"LG in k2: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}") + + lexicon = Lexicon(lang_dir) + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + logging.info(f"token id for #0: {first_token_disambig_id}") + logging.info(f"word id for #0: {first_word_disambig_id}") + + max_token_id = max(lexicon.tokens) + modified = False + logging.info( + f"Building ctc_topo. modified: {modified}, max_token_id: {max_token_id}" + ) + + H = k2.ctc_topo(max_token_id, modified=modified) + logging.info(f"H: #states: {H.shape[0]}, #arcs: {H.num_arcs}") + + logging.info("Removing disambiguation symbols on LG") + LG.labels[LG.labels >= first_token_disambig_id] = 0 + LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0 + + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + logging.info("Removing epsilons from LG") + LG = k2.remove_epsilon(LG) + logging.info( + f"LG after k2.remove_epsilon: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}" + ) + + logging.info("Connecting LG after removing epsilons") + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + logging.info(f"LG after k2.connect: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}") + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + + HLG = k2.compose(H, LG, inner_labels="tokens") + logging.info( + f"HLG after k2.compose: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}" + ) + + logging.info("Connecting HLG") + HLG = k2.connect(HLG) + logging.info( + f"HLG after k2.connect: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}" + ) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + + return HLG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + filename = lang_dir / "HLG_fst.pt" + + if filename.is_file(): + logging.info(f"{filename} already exists - skipping") + return + + HLG = compile_HLG(lang_dir, args.lm) + logging.info(f"Saving HLG to {filename}") + torch.save(HLG.as_dict(), filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, 
level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 45c4b7f5f..709b14070 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -45,11 +45,18 @@ def get_args(): help="""Input and output directory. """, ) + parser.add_argument( + "--lm", + type=str, + default="G_3_gram", + help="""Stem name for LM used in HLG compiling. + """, + ) return parser.parse_args() -def compile_LG(lang_dir: str) -> k2.Fsa: +def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -61,15 +68,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa: lexicon = Lexicon(lang_dir) L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) - if Path("data/lm/G_3_gram.pt").is_file(): - logging.info("Loading pre-compiled G_3_gram") - d = torch.load("data/lm/G_3_gram.pt") + if Path(f"data/lm/{lm}.pt").is_file(): + logging.info(f"Loading pre-compiled {lm}") + d = torch.load(f"data/lm/{lm}.pt") G = k2.Fsa.from_dict(d) else: - logging.info("Loading G_3_gram.fst.txt") - with open("data/lm/G_3_gram.fst.txt") as f: + logging.info(f"Loading {lm}.fst.txt") + with open(f"data/lm/{lm}.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + torch.save(G.as_dict(), f"data/lm/{lm}.pt") first_token_disambig_id = lexicon.token_table["#0"] first_word_disambig_id = lexicon.word_table["#0"] @@ -96,10 +103,11 @@ def compile_LG(lang_dir: str) -> k2.Fsa: logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels assert isinstance(LG.aux_labels, k2.RaggedTensor) LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 @@ -126,15 +134,13 @@ def main(): logging.info(f"Processing {lang_dir}") - LG = compile_LG(lang_dir) + LG = compile_LG(lang_dir, args.lm) logging.info(f"Saving LG.pt to {lang_dir}") torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index c0c7ef8c5..97750f3ea 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,9 +80,7 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 5587106e5..ce0ef24e7 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -144,9 +144,7 @@ def main(): 
date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index ce7d087f0..25d6050bb 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -35,7 +35,7 @@ from filter_cuts import filter_cuts from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -54,10 +54,28 @@ def get_args(): help="""Path to the bpe.model. If not None, we will remove short and long utterances before extracting features""", ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + return parser.parse_args() -def compute_fbank_librispeech(bpe_model: Optional[str] = None): +def compute_fbank_librispeech( + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -68,15 +86,19 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): sp = spm.SentencePieceProcessor() sp.load(bpe_model) - dataset_parts = ( - "dev-clean", - "dev-other", - "test-clean", - "test-other", - "train-clean-100", - "train-clean-360", - "train-other-500", - ) + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + prefix = "librispeech" suffix = "jsonl.gz" manifests = read_manifests_if_cached( @@ -107,15 +129,17 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): recordings=m["recordings"], supervisions=m["supervisions"], ) - if bpe_model: - cut_set = filter_cuts(cut_set, sp) if "train" in partition: - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -128,11 +152,13 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() logging.info(vars(args)) - compute_fbank_librispeech(bpe_model=args.bpe_model) + compute_fbank_librispeech( + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff 
--git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 056da29e5..62036467e 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -28,7 +28,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -41,6 +41,10 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) +def is_cut_long(c: MonoCut) -> bool: + return c.duration > 5 + + def compute_fbank_musan(): src_dir = Path("data/manifests") output_dir = Path("data/fbank") @@ -83,12 +87,10 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) + recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) + .filter(is_cut_long) .compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/musan_feats", @@ -101,9 +103,7 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index 133499c8b..a8d5117c9 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -51,16 +51,12 @@ def get_args(): "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." - ) + parser.add_argument("--oov", type=str, default="", help="The OOV word.") return parser.parse_args() -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 030122aa7..3518db524 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,9 +87,7 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index 53dbb8211..fbcc9e24a 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,8 +79,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) removed += 1 return False @@ -101,6 +100,9 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): # Note: for ./lstm_transducer_stateless/lstm.py, the formula is # T = ((num_frames - 3) // 2 - 1) // 2 + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) if T < len(tokens): @@ -122,8 +124,7 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. " - f"{ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." ) return ans @@ -152,9 +153,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 566c0743d..3459c2f5a 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,9 +91,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index dec8a7442..2a2d9c219 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -127,7 +127,7 @@ def lexicon_to_fst_no_sil( def generate_lexicon( - model_file: str, words: List[str] + model_file: str, words: List[str], oov: str ) -> Tuple[Lexicon, Dict[str, int]]: """Generate a lexicon from a BPE model. @@ -136,6 +136,8 @@ def generate_lexicon( Path to a sentencepiece model. words: A list of strings representing words. + oov: + The out of vocabulary word in lexicon. Returns: Return a tuple with two elements: - A dict whose keys are words and values are the corresponding @@ -150,20 +152,15 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. 
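The lines around this comment build the word-to-pieces lexicon by round-tripping words through SentencePiece IDs. A minimal standalone sketch, assuming a trained BPE model at a placeholder path:

```python
# Minimal sketch of the word -> BPE-pieces mapping built in generate_lexicon().
# The model path is a placeholder; the exact pieces depend on the trained model.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

words = ["HELLO", "WORLD"]
ids = sp.encode(words, out_type=int)           # List[List[int]]
pieces = [sp.id_to_piece(seq) for seq in ids]  # List[List[str]]

for word, ps in zip(words, pieces):
    print(word, ps)  # e.g. HELLO ['▁HE', 'LL', 'O']
```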
- words_pieces: List[List[str]] = [ - sp.id_to_piece(ids) for ids in words_pieces_ids - ] + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] lexicon = [] for word, pieces in zip(words, words_pieces): lexicon.append((word, pieces)) - # The OOV word is - lexicon.append(("", [sp.id_to_piece(sp.unk_id())])) + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) - token2id: Dict[str, int] = dict() - for i in range(sp.vocab_size()): - token2id[sp.id_to_piece(i)] = i + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} return lexicon, token2id @@ -178,6 +175,13 @@ def get_args(): """, ) + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + parser.add_argument( "--debug", type=str2bool, @@ -204,12 +208,13 @@ def main(): words = word_sym_table.symbols - excluded = ["", "!SIL", "", "", "#0", "", ""] + excluded = ["", "!SIL", "", args.oov, "#0", "", ""] + for w in excluded: if w in words: words.remove(w) - lexicon, token_sym_table = generate_lexicon(model_file, words) + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 5070341f1..70343fef7 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,8 +137,7 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} " - f"({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -154,18 +153,14 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor( - sentence_lengths, dtype=torch.int32 - ) + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 077f23039..8aa5e461d 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,9 +119,7 @@ def preprocess_giga_speech(): def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index d4cf62bba..74e025ad7 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,9 +88,7 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) + fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) fsa_disambig.labels_sym = 
k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 42aba9572..43142aee4 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -89,6 +89,9 @@ def main(): bos_id=-1, eos_id=-1, ) + else: + print(f"{model_file} exists - skipping") + return shutil.copyfile(model_file, f"{lang_dir}/bpe.model") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 7c57d629a..de49f5321 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -35,6 +35,7 @@ from pathlib import Path from lhotse import CutSet, load_manifest_lazy from lhotse.cut import Cut +from lhotse.dataset.speech_recognition import validate_for_asr def get_args(): @@ -55,16 +56,22 @@ def validate_one_supervision_per_cut(c: Cut): def validate_supervision_and_cut_time_bounds(c: Cut): + tol = 2e-3 # same tolerance as in 'validate_for_asr()' s = c.supervisions[0] - if s.start < c.start: - raise ValueError( - f"{c.id}: Supervision start time {s.start} is less " - f"than cut start time {c.start}" - ) - if s.end > c.end: + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + if s.start < -tol: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " + f"{c.id}: Supervision start time {s.start} must not be negative." + ) + if s.start > tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`." + ) + if c.start + s.end > c.end + tol: + raise ValueError( + f"{c.id}: Supervision end time {c.start+s.end} is larger " f"than cut end time {c.end}" ) @@ -83,11 +90,15 @@ def main(): validate_one_supervision_per_cut(c) validate_supervision_and_cut_time_bounds(c) + # Validation from K2 training + # - checks supervision start is 0 + # - checks supervision.duration is not longer than cut.duration + # - there is tolerance 2ms + validate_for_asr(cut_set) + if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 27414d717..856c9d945 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -272,8 +272,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -366,9 +365,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -427,10 +424,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -561,9 +555,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -574,18 +566,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -595,10 +583,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -648,9 +633,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -682,9 +665,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -711,9 +694,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -772,9 +755,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 
1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py new file mode 120000 index 000000000..9f5064deb --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index 13dac6009..e338342cc 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -281,9 +280,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -310,9 +309,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -380,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py index 594c33e4f..c07956243 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -123,10 +123,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -314,9 +313,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index c54a4c478..bbab16af7 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -672,9 +672,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -771,16 +769,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index d71132b4a..e7bad7ed8 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -151,9 +149,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py new file mode 120000 index 000000000..0b1ea0326 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py new file mode 120000 index 000000000..099c2882f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 2a6e2adc6..b3a34a9e3 100755 --- 
a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +197,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index 97d890c82..d8f7fd960 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -70,14 +70,12 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) self.hyp: Optional[List[int]] = None else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index d6376bdc0..f989d9bc0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -199,8 +199,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +358,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +375,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +534,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +576,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +588,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -752,17 +742,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -772,10 +758,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -816,9 +799,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,9 +833,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +862,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index d30fc260a..feb81d500 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -222,8 +220,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -246,8 +243,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -594,11 +590,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
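The warmup value gates how strongly the pruned RNN-T loss contributes. A minimal sketch of the schedule that the reformatted expression just below encodes (same behaviour, written out as an illustrative helper; the name _pruned_loss_scale is hypothetical):

def _pruned_loss_scale(warmup: float) -> float:
    # 0.0 until warmup reaches 1.0, 0.1 while 1.0 < warmup < 2.0, then 1.0
    if warmup < 1.0:
        return 0.0
    if 1.0 < warmup < 2.0:
        return 0.1
    return 1.0

# loss = params.simple_loss_scale * simple_loss + _pruned_loss_scale(warmup) * pruned_loss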
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -638,9 +630,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -653,14 +643,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -671,9 +656,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -856,9 +839,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -973,10 +954,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -989,8 +970,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index bad4e243e..1a724830b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -93,22 +93,40 @@ Usage: --max-contexts 8 \ --max-states 64 -(8) modified beam search (with RNNLM shallow fusion) +(8) modified beam search (with LM shallow fusion) ./lstm_transducer_stateless2/decode.py \ --epoch 35 \ --avg 15 \ --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ - --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --decoding-method modified_beam_search_lm_shallow_fusion \ --beam 4 \ - --rnn-lm-scale 0.3 \ - --rnn-lm-exp-dir /path/to/RNNLM \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 -""" +(9) modified beam search with LM shallow fusion + LODR +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --decoding-method modified_beam_search_LODR \ + --beam 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ +""" import argparse import logging @@ -131,13 +149,14 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, modified_beam_search_ngram_rescoring, - modified_beam_search_rnnlm_shallow_fusion, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall import NgramLm +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -145,7 +164,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -235,7 +253,8 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - modified_beam_search_ngram_rescoring - - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_lm_shallow_fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -295,8 +314,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -326,67 +344,28 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-scale", - type=float, - default=0.0, - help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the checkpoint to use. 
- """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. """, ) @@ -395,7 +374,8 @@ def get_parser(): type=int, default=3, help="""Token Ngram used for rescoring. - Used only when the decoding method is modified_beam_search_ngram_rescoring""", + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", ) parser.add_argument( @@ -403,7 +383,8 @@ def get_parser(): type=int, default=500, help="""ID of the backoff symbol. - Used only when the decoding method is modified_beam_search_ngram_rescoring""", + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", ) add_model_arguments(parser) @@ -420,8 +401,7 @@ def decode_one_batch( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -450,6 +430,9 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. Returns: Return the decoding result. See above description for the format of the returned dict. 
@@ -474,9 +457,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -535,10 +516,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -566,15 +544,25 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -629,8 +617,7 @@ def decode_dataset( decoding_graph: Optional[k2.Fsa] = None, ngram_lm: Optional[NgramLm] = None, ngram_lm_scale: float = 1.0, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -649,6 +636,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -682,8 +671,7 @@ def decode_dataset( batch=batch, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -700,9 +688,8 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -713,18 +700,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -734,10 +717,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -755,6 +735,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -769,8 +750,9 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", "modified_beam_search_ngram_rescoring", - "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -789,15 +771,22 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - if "rnnlm" in params.decoding_method: - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -826,9 +815,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -860,9 +849,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -912,7 +901,7 @@ def main(): model.eval() # only load N-gram LM when needed - if "ngram" in params.decoding_method: + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: lm_filename = f"{params.tokens_ngram}gram.fst.txt" logging.info(f"lm filename: {lm_filename}") ngram_lm = NgramLm( @@ -921,33 +910,23 @@ def main(): is_binary=False, ) logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale else: ngram_lm 
= None ngram_lm_scale = None - # only load rnnlm if used - if "rnnlm" in params.decoding_method: - rnn_lm_scale = params.rnn_lm_scale - - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, ) - assert params.rnn_lm_avg == 1 - - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, - ) - rnn_lm_model.to(device) - rnn_lm_model.eval() - + LM.to(device) + LM.eval() else: - rnn_lm_model = None - rnn_lm_scale = 0.0 + LM = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -961,9 +940,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -977,7 +954,9 @@ def main(): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) test_clean_cuts = librispeech.test_clean_cuts() + # test_clean_cuts = test_clean_cuts.subset(first=500) test_other_cuts = librispeech.test_other_cuts() + # test_other_cuts = test_other_cuts.subset(first=500) test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) @@ -995,8 +974,7 @@ def main(): decoding_graph=decoding_graph, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, - rnnlm=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py new file mode 100755 index 000000000..08bfcb204 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export via torch.jit.trace() + +./lstm_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + +cd ./lstm_transducer_stateless2/exp +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. 
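The export step itself is small: each component is traced with example inputs and saved as TorchScript, which pnnx then converts for ncnn. A condensed sketch of the encoder case, mirroring export_encoder_model_jit_trace() defined later in this file (the decoder and joiner follow the same pattern with their own example inputs):

import torch

# `model` is the transducer created below with params.is_pnnx = True.
x = torch.zeros(1, 100, 80, dtype=torch.float32)  # (N, T, C) fbank features
x_lens = torch.tensor([100], dtype=torch.int64)
states = model.encoder.get_init_states()

traced_encoder = torch.jit.trace(model.encoder, (x, x_lens, states))
traced_encoder.save("encoder_jit_trace-pnnx.pt")
# Then: pnnx encoder_jit_trace-pnnx.pt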
+""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. 
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + params.is_pnnx = True + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) 
+ else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py new file mode 100755 index 000000000..f068f6a0f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -0,0 +1,632 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lexicon.txt" +git lfs pull --include "data/L.pt" +git lfs pull --include "exp/epoch-11.pt" +git lfs pull --include "exp/epoch-10.pt" + +popd + +2. 
Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx-zh.py \ + --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \ + --use-averaged-model 1 \ + --epoch 11 \ + --avg 1 \ + --exp-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/exp \ + --num-encoder-layers 12 \ + --encoder-dim 512 \ + --rnn-hidden-size 1024 + +It will generate the following files inside $repo/exp: + + - encoder-epoch-11-avg-1.onnx + - decoder-epoch-11-avg-1.onnx + - joiner-epoch-11-avg-1.onnx + - encoder-epoch-11-avg-1.int8.onnx + - decoder-epoch-11-avg-1.int8.onnx + - joiner-epoch-11-avg-1.int8.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Optional, Tuple + +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from lstm import RNN +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for RNN and the encoder_proj from the joiner""" + + def __init__(self, encoder: RNN, encoder_proj: nn.Linear): + """ + Args: + encoder: + An RNN encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of RNN.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - updated states, whose shape is the same as the input states. + """ + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device) + encoder_out, _, next_states = self.encoder(x, x_lens, states) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, next_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + and it has 3 outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + d_model = encoder_model.encoder.d_model + rnn_hidden_size = encoder_model.encoder.rnn_hidden_size + + decode_chunk_len = 4 + T = 9 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.get_init_states() + # state0: (num_encoder_layers, batch_size, d_model) + # state1: (num_encoder_layers, batch_size, rnn_hidden_size) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "state0", "state1"], + output_names=["encoder_out", "new_state0", "new_state1"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "state0": {1: "N"}, + "state1": {1: "N"}, + "encoder_out": {0: "N"}, + "new_state0": {1: "N"}, + "new_state1": {1: "N"}, + }, + ) + + meta_data = { + "model_type": "lstm", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": str(num_encoder_layers), + "d_model": str(d_model), + "rnn_hidden_size": str(rnn_hidden_size), + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end 
= f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py new file mode 100755 index 000000000..acaff8540 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -0,0 +1,626 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" 
+This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + +It will generate the following files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + - encoder-epoch-99-avg-1.int8.onnx + - decoder-epoch-99-avg-1.int8.onnx + - joiner-epoch-99-avg-1.int8.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Optional, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from lstm import RNN +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for RNN and the encoder_proj from the joiner""" + + def __init__(self, encoder: RNN, encoder_proj: nn.Linear): + """ + Args: + encoder: + An RNN encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of RNN.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - updated states, whose shape is the same as the input states. + """ + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device) + encoder_out, _, next_states = self.encoder(x, x_lens, states) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, next_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). 
+ Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + and it has 3 outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + d_model = encoder_model.encoder.d_model + rnn_hidden_size = encoder_model.encoder.rnn_hidden_size + + decode_chunk_len = 4 + T = 9 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.get_init_states() + # state0: (num_encoder_layers, batch_size, d_model) + # state1: (num_encoder_layers, batch_size, rnn_hidden_size) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "state0", "state1"], + output_names=["encoder_out", "new_state0", "new_state1"], + dynamic_axes={ + "x": {0: "N"}, + "state0": {1: "N"}, + "state1": {1: "N"}, + "encoder_out": {0: "N"}, + "new_state0": {1: "N"}, + "new_state1": {1: "N"}, + }, + ) + + meta_data = { + "model_type": "lstm", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": str(num_encoder_layers), + "d_model": str(d_model), + "rnn_hidden_size": str(rnn_hidden_size), + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
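To give a feel for how the exported decoder and joiner fit together, here is a rough single-step sketch with onnxruntime; filenames follow the example above, the blank id is assumed to be 0, and a real search loop (see ./onnx_pretrained.py) also has to handle blanks and repeated symbols:

import numpy as np
import onnxruntime as ort

decoder = ort.InferenceSession("decoder-epoch-99-avg-1.onnx")
joiner = ort.InferenceSession("joiner-epoch-99-avg-1.onnx")

context_size = int(decoder.get_modelmeta().custom_metadata_map["context_size"])
blank_id = 0  # assumed

# Decoder: the last `context_size` emitted tokens -> (N, joiner_dim)
y = np.array([[blank_id] * context_size], dtype=np.int64)
decoder_out = decoder.run(["decoder_out"], {"y": y})[0]

# Joiner: one projected encoder frame + decoder_out -> (N, vocab_size)
encoder_frame = np.zeros_like(decoder_out)  # placeholder for one frame of encoder_out
logit = joiner.run(
    ["logit"], {"encoder_out": encoder_frame, "decoder_out": decoder_out}
)[0]
next_token = int(logit.argmax(axis=-1)[0])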
+ """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter 
{params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + 
weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 190673638..0adc68112 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -74,29 +74,6 @@ with the following commands: git lfs install git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 # You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp - -(3) Export to ONNX format - -./lstm_transducer_stateless2/export.py \ - --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Please see ./streaming-onnx-decode.py for usage of the generated files - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. """ import argparse @@ -192,41 +169,11 @@ def get_parser(): """, ) - parser.add_argument( - "--pnnx", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.trace for later - converting to PNNX. It will generate 3 files: - - encoder_jit_trace-pnnx.pt - - decoder_jit_trace-pnnx.pt - - joiner_jit_trace-pnnx.pt - """, - ) - - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit and --pnnx are ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - parser.add_argument( "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -306,215 +253,6 @@ def export_joiner_model_jit_trace( logging.info(f"Saved to {joiner_filename}") -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has 3 inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - states: a tuple containing: - - h0: a tensor of shape (num_layers, N, proj_size) - - c0: a tensor of shape (num_layers, N, hidden_size) - - and it has 3 outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - states: a tuple containing: - - next_h0: a tensor of shape (num_layers, N, proj_size) - - next_c0: a tensor of shape (num_layers, N, hidden_size) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. 
- opset_version: - The opset version to use. - """ - N = 1 - x = torch.zeros(N, 9, 80, dtype=torch.float32) - x_lens = torch.tensor([9], dtype=torch.int64) - h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand( - encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size - ) - - warmup = 1.0 - torch.onnx.export( - encoder_model, # use torch.jit.trace() internally - (x, x_lens, (h, c), warmup), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens", "h", "c", "warmup"], - output_names=["encoder_out", "encoder_out_lens", "next_h", "next_c"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "h": {1: "N"}, - "c": {1: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - "next_h": {1: "N"}, - "next_c": {1: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) - - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "projected_encoder_out", - "projected_decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "projected_encoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -538,10 +276,6 @@ def main(): logging.info(params) - if params.pnnx: - params.is_pnnx = params.pnnx - logging.info("For PNNX") - logging.info("About to create model") model = get_transducer_model(params, enable_giga=False) @@ -550,9 +284,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -585,9 +319,9 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -636,44 +370,7 @@ def main(): model.to("cpu") 
model.eval() - if params.onnx: - logging.info("Export model to ONNX format") - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - - opset_version = 11 - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - - elif params.pnnx: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - elif params.jit_trace is True: + if params.jit_trace is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.trace()") encoder_filename = params.exp_dir / "encoder_jit_trace.pt" @@ -694,9 +391,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index da184b76f..728b09104 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -124,10 +124,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -315,9 +314,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index fadeb4ac2..4957d14b1 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 410de8d3d..3eeaa5397 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -19,7 +19,7 @@ """ Usage: ./lstm_transducer_stateless2/ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ @@ -27,15 +27,19 @@ Usage: --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ ./test_wavs/1089-134686-0001.wav + +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for details. 
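+
+Note: the *.ncnn.param and *.ncnn.bin files above are assumed to be obtained
+by converting torch.jit.trace() exports of the encoder/decoder/joiner with
+pnnx/ncnn; the link above describes the conversion steps.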
""" import argparse import logging from typing import List +import k2 import kaldifeat import ncnn -import sentencepiece as spm import torch import torchaudio @@ -44,9 +48,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--bpe-model-filename", + "--tokens", type=str, - help="Path to bpe.model", + help="Path to tokens.txt", ) parser.add_argument( @@ -104,6 +108,8 @@ class Model: encoder_net = ncnn.Net() encoder_net.opt.use_packing_layout = False encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + encoder_param = args.encoder_param_filename encoder_model = args.encoder_bin_filename @@ -118,6 +124,7 @@ class Model: decoder_net = ncnn.Net() decoder_net.opt.use_packing_layout = False + decoder_net.opt.num_threads = 4 decoder_net.load_param(decoder_param) decoder_net.load_model(decoder_model) @@ -129,6 +136,8 @@ class Model: joiner_model = args.joiner_bin_filename joiner_net = ncnn.Net() joiner_net.opt.use_packing_layout = False + joiner_net.opt.num_threads = 4 + joiner_net.load_param(joiner_param) joiner_net.load_model(joiner_model) @@ -136,7 +145,6 @@ class Model: def run_encoder(self, x, states): with self.encoder_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(x.numpy()).clone()) x_lens = torch.tensor([x.size(0)], dtype=torch.float32) ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) @@ -156,9 +164,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -167,7 +173,6 @@ class Model: assert decoder_input.dtype == torch.int32 with self.decoder_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") assert ret == 0, ret @@ -176,7 +181,6 @@ class Model: def run_joiner(self, encoder_out, decoder_out): with self.joiner_net.create_extractor() as ex: - ex.set_num_threads(10) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") @@ -200,10 +204,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -241,9 +244,6 @@ def main(): model = Model(args) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model_filename) - sound_file = args.sound_filename sample_rate = 16000 @@ -281,14 +281,20 @@ def main(): encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states) hyp = greedy_search(model, encoder_out) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + logging.info(sound_file) - logging.info(sp.decode(hyp)) + logging.info(text) if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py new file mode 100755 index 000000000..c83f38b2a --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script checks that exported ONNX models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export the model via torch.jit.trace() + +./lstm_transducer_stateless2/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit-trace 1 + +It will generate the following 3 files inside $repo/exp + + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + +3. Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. 
Run this file + +./lstm_transducer_stateless2/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +""" + +import argparse +import logging + +from onnx_pretrained import OnnxModel + +from icefall import is_module_available + +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-encoder-filename", + required=True, + type=str, + help="Path to the torchscript encoder model", + ) + + parser.add_argument( + "--jit-decoder-filename", + required=True, + type=str, + help="Path to the torchscript decoder model", + ) + + parser.add_argument( + "--jit-joiner-filename", + required=True, + type=str, + help="Path to the torchscript joiner model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the ONNX encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the ONNX decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the ONNX joiner model", + ) + + return parser + + +def test_encoder( + torch_encoder_model: torch.jit.ScriptModule, + torch_encoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + N = torch.randint(1, 100, size=(1,)).item() + T = onnx_model.segment + C = 80 + x_lens = torch.tensor([T] * N) + torch_states = torch_encoder_model.get_init_states(N) + + onnx_model.init_encoder_states(N) + + for i in range(5): + logging.info(f"test_encoder: iter {i}") + x = torch.rand(N, T, C) + torch_encoder_out, _, torch_states = torch_encoder_model( + x, x_lens, torch_states + ) + torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) + + onnx_encoder_out = onnx_model.run_encoder(x) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_decoder_model: torch.jit.ScriptModule, + torch_decoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_joiner_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] + decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = 
torch_joiner_model.encoder_proj(encoder_out) + projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out) + + torch_joiner_out = torch_joiner_model(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_encoder_model = torch.jit.load(args.jit_encoder_filename) + torch_decoder_model = torch.jit.load(args.jit_decoder_filename) + torch_joiner_model = torch.jit.load(args.jit_joiner_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + # When exporting the model to onnx, we have already put the encoder_proj + # inside the encoder. + test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model) + + logging.info("Test decoder") + # When exporting the model to onnx, we have already put the decoder_proj + # inside the decoder. + test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_joiner_model, onnx_model) + + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20230207) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py new file mode 100755 index 000000000..fb9e121e5 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" +cd exp +ln -s exp/pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./lstm_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1221-135766-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + model_type = encoder_meta["model_type"] + assert model_type == "lstm", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = int(encoder_meta["num_encoder_layers"]) + d_model = int(encoder_meta["d_model"]) + rnn_hidden_size = int(encoder_meta["rnn_hidden_size"]) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"d_model: {d_model}") + logging.info(f"rnn_hidden_size: {rnn_hidden_size}") + + N = batch_size + + s0 = torch.zeros(num_encoder_layers, N, d_model) + s1 = torch.zeros(num_encoder_layers, N, rnn_hidden_size) + states = [s0.numpy(), s1.numpy()] + + self.states = states + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = { + "x": x.numpy(), + "state0": self.states[0], + "state1": self.states[1], + } + encoder_output = ["encoder_out", "new_state0", "new_state1"] + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-3)//2-1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + 
)[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. 
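+      Besides the decoded results, the updated decoder output is also
+      returned, so that both can be passed back in for the next chunk,
+      e.g. (sketch of the per-chunk call used in main() below):
+
+          hyp, decoder_out = greedy_search(
+              model, encoder_out, context_size, decoder_out, hyp
+          )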
+ """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index bef0ad760..f3f272b9f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -169,8 +169,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,10 +200,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -267,15 +265,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -347,9 +341,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index e47a05a9e..cbbc77928 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -16,13 +16,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for usage +""" import argparse import logging from typing import List, Optional +import k2 import ncnn -import sentencepiece as spm import torch import torchaudio from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature @@ -32,9 +37,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--bpe-model-filename", + "--tokens", type=str, - help="Path to bpe.model", + help="Path to tokens.txt", ) parser.add_argument( @@ -92,6 +97,8 @@ class Model: encoder_net = ncnn.Net() encoder_net.opt.use_packing_layout = False encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + encoder_param = args.encoder_param_filename encoder_model = args.encoder_bin_filename @@ -106,6 +113,7 @@ class Model: decoder_net = ncnn.Net() decoder_net.opt.use_packing_layout = False + decoder_net.opt.num_threads = 4 decoder_net.load_param(decoder_param) decoder_net.load_model(decoder_model) @@ -117,6 +125,8 @@ class Model: joiner_model = args.joiner_bin_filename joiner_net = ncnn.Net() joiner_net.opt.use_packing_layout = False + joiner_net.opt.num_threads = 4 + joiner_net.load_param(joiner_param) joiner_net.load_model(joiner_model) @@ -124,7 +134,6 @@ class Model: def run_encoder(self, x, states): with self.encoder_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(x.numpy()).clone()) x_lens = torch.tensor([x.size(0)], dtype=torch.float32) ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) @@ -144,9 +153,7 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( - torch.int32 - ) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, 
encoder_out_lens, hx, cx @@ -155,7 +162,6 @@ class Model: assert decoder_input.dtype == torch.int32 with self.decoder_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") assert ret == 0, ret @@ -164,7 +170,6 @@ class Model: def run_joiner(self, encoder_out, decoder_out): with self.joiner_net.create_extractor() as ex: - # ex.set_num_threads(10) ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) ret, ncnn_out0 = ex.extract("out0") @@ -188,10 +193,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -229,9 +233,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - hyp, dtype=torch.int32 - ) # (1, context_size) + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -254,9 +256,6 @@ def main(): model = Model(args) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model_filename) - sound_file = args.sound_filename sample_rate = 16000 @@ -310,9 +309,7 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -328,24 +325,26 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder( - frames, states - ) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp ) + symbol_table = k2.SymbolTable.from_file(args.tokens) + context_size = 2 + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() logging.info(sound_file) - logging.info(sp.decode(hyp[context_size:])) + logging.info(text) if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 232d3dd18..34d2e5630 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -147,10 +147,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -199,9 +198,7 @@ class Model: sess_options=self.session_opts, ) - def run_encoder( - self, x, h0, c0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -258,9 +255,7 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj( - torch.from_numpy(decoder_out).squeeze(1) - ) + return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) def run_joiner( self, @@ -303,11 +298,7 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - { - self.joiner_encoder_proj.get_inputs()[ - 0 - ].name: encoder_out.numpy() - }, + {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, )[0] return torch.from_numpy(projected_encoder_out) @@ -326,11 +317,7 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - { - self.joiner_decoder_proj.get_inputs()[ - 0 - ].name: decoder_out.numpy() - }, + {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, )[0] return torch.from_numpy(projected_decoder_out) @@ -369,9 +356,7 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor( - [hyp], dtype=torch.int64 - ) # (1, context_size) + decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -474,9 +459,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 5eaaf321f..4fc4fa7f8 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,9 +95,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -238,8 +235,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -262,8 +258,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -645,11 +640,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -692,9 +683,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -707,14 +696,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -725,9 +709,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -958,9 +940,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1006,8 +986,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) return False @@ -1129,10 +1108,10 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) @@ -1155,9 +1134,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 9eee19379..a2b4f9e1a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -94,10 +94,7 @@ Usage: To evaluate symbol delay, you should: (1) Generate cuts with word-time alignments: -./local/add_alignment_librispeech.py \ - --alignments-dir data/alignment \ - --cuts-in-dir data/fbank \ - --cuts-out-dir data/fbank_ali +./add_alignments.sh (2) Set the argument "--manifest-dir data/fbank_ali" while decoding. For example: ./lstm_transducer_stateless3/decode.py \ @@ -290,8 +287,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -386,9 +382,7 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -441,10 +435,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -490,7 +481,6 @@ def decode_one_batch( res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( - decoding_method=params.decoding_method, res=res, sp=sp, subsampling_factor=params.subsampling_factor, @@ -522,9 +512,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. 
Args: @@ -599,9 +587,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -610,9 +596,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -627,18 +611,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True @@ -649,10 +629,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -660,8 +637,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) @@ -678,9 +654,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -724,9 +698,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -758,9 +730,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -787,9 +759,9 @@ def main(): 
model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -848,9 +820,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py new file mode 120000 index 000000000..d56cff73f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-for-ncnn.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py new file mode 120000 index 000000000..9f5064deb --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 212c7bad6..a82cad043 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -172,8 +172,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -281,9 +280,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -310,9 +309,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -380,9 +379,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index a3443cf0a..237591a36 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -123,10 +123,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -314,9 +313,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 90bc351f4..59a835d35 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -121,6 +121,8 @@ class RNN(EncoderInterface): Period of auxiliary layers used for random combiner during training. If set to 0, will not use the random combiner (Default). You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. """ def __init__( @@ -135,6 +137,7 @@ class RNN(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, aux_layer_period: int = 0, + is_pnnx: bool = False, ) -> None: super(RNN, self).__init__() @@ -148,7 +151,13 @@ class RNN(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx self.num_encoder_layers = num_encoder_layers self.d_model = d_model @@ -216,7 +225,13 @@ class RNN(EncoderInterface): # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning # # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 3) >> 1) - 1) >> 1 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() @@ -377,7 +392,7 @@ class RNNEncoderLayer(nn.Module): # for cell state assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) src_lstm, new_states = self.lstm(src, states) - src = src + self.dropout(src_lstm) + src = self.dropout(src_lstm) + src # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -523,6 +538,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, + is_pnnx: bool = False, ) -> None: """ Args: @@ -535,6 +551,9 @@ class Conv2dSubsampling(nn.Module): Number of channels in layer1 layer1_channels: Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. """ assert in_channels >= 9 super().__init__() @@ -577,6 +596,10 @@ class Conv2dSubsampling(nn.Module): channel_dim=-1, min_positive=0.45, max_positive=0.55 ) + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
@@ -590,9 +613,15 @@ class Conv2dSubsampling(nn.Module): # On entry, x is (N, T, idim) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-3)//2-1))//2, odim) x = self.out_norm(x) x = self.out_balancer(x) @@ -661,9 +690,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -760,16 +787,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py new file mode 120000 index 000000000..0b1ea0326 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py new file mode 120000 index 000000000..099c2882f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 0e48fef04..f49e9c518 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +197,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index cfa918ed5..c737e3611 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -199,8 +199,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -359,9 +358,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -378,9 +375,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,9 +534,7 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor( - num_processed_frames_list, device=device - ) + num_processed_frames = torch.tensor(num_processed_frames_list, device=device) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -583,8 +576,7 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor - + encoder_out_lens + num_processed_frames // params.subsampling_factor + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -596,9 +588,7 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") # Update cached states of each stream state_list = unstack_states(states) @@ -752,17 +742,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / 
f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -772,10 +758,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -816,9 +799,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -852,9 +833,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -881,9 +862,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 60a5a2be7..6ef4c9860 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,9 +87,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -104,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dim", type=int, default=512, - help="Encoder output dimesion.", + help="Encoder output dimension.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Decoder output dimension.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="Joiner output dimension.", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Dimension of feed forward.", ) parser.add_argument( @@ -232,8 +251,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -256,8 +274,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -399,14 +416,10 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "dim_feedforward": 2048, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), + "is_pnnx": False, } ) @@ -423,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, aux_layer_period=params.aux_layer_period, + is_pnnx=params.is_pnnx, ) return encoder @@ -606,11 +620,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,9 +660,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -665,14 +673,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -683,9 +686,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
- info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -852,10 +853,7 @@ def train_one_epoch( rank=rank, ) - if ( - batch_idx % params.log_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.log_interval == 0 and not params.print_diagnostics: cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -872,9 +870,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if ( batch_idx > 0 @@ -1009,8 +1005,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 94e003036..8342d5212 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 @@ -41,9 +44,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - 5000 - 2000 - 1000 + # 5000 + # 2000 + # 1000 500 ) @@ -120,6 +123,13 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then touch data/fbank/.librispeech.done fi + if [ ! -f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz + fi + if [ ! -e data/fbank/.librispeech-validated.done ]; then log "Validating data/fbank for LibriSpeech" parts=( @@ -160,6 +170,22 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ ! -f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang.py --lang-dir $lang_dir fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi fi @@ -200,11 +226,27 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then --lexicon $lang_dir/lexicon.txt \ --bpe-model $lang_dir/bpe.model fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! 
-f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi done fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare bigram P" + log "Stage 7: Prepare bigram token-level P for MMI training" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} @@ -263,9 +305,19 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then log "Stage 9: Compile HLG" ./local/compile_hlg.py --lang-dir data/lang_phone + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir done fi diff --git a/egs/librispeech/ASR/prepare_common_voice.sh b/egs/librispeech/ASR/prepare_common_voice.sh new file mode 100755 index 000000000..6f9c4fb2f --- /dev/null +++ b/egs/librispeech/ASR/prepare_common_voice.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=16 +stage=-1 +stop_stage=100 + +# Split data/${lang}set to this number of pieces +# This is to avoid OOM during feature extraction. +num_splits=1000 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/$release/$lang +# This directory contains the following files downloaded from +# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz +# +# - clips +# - dev.tsv +# - invalidated.tsv +# - other.tsv +# - reported.tsv +# - test.tsv +# - train.tsv +# - validated.tsv + +dl_dir=$PWD/download +release=cv-corpus-13.0-2023-03-09 +lang=en + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data/${lang}". +# You can safely remove "data/${lang}" and rerun this script to regenerate it. +mkdir -p data/${lang} + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/$release, + # you can create a symlink + # + # ln -sfv /path/to/$release $dl_dir/$release + # + if [ ! -d $dl_dir/$release/$lang/clips ]; then + lhotse download commonvoice --languages $lang --release $release $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare CommonVoice manifest" + # We assume that you have downloaded the CommonVoice corpus + # to $dl_dir/$release + mkdir -p data/${lang}/manifests + if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then + lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests + touch data/${lang}/manifests/.cv-${lang}.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Preprocess CommonVoice manifest" + if [ ! 
-e data/${lang}/fbank/.preprocess_complete ]; then + ./local/preprocess_commonvoice.py --language $lang + touch data/${lang}/fbank/.preprocess_complete + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for dev and test subsets of CommonVoice" + mkdir -p data/${lang}/fbank + if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then + ./local/compute_fbank_commonvoice_dev_test.py --language $lang + touch data/${lang}/fbank/.cv-${lang}_dev_test.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split train subset into ${num_splits} pieces" + split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits} + if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then + lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir + touch $split_dir/.cv-${lang}_train_split.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute features for train subset of CommonVoice" + if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then + ./local/compute_fbank_commonvoice_splits.py \ + --num-workers $nj \ + --batch-duration 600 \ + --start 0 \ + --num-splits $num_splits \ + --language $lang + touch data/${lang}/fbank/.cv-${lang}_train.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Combine features for train" + if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then + pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz") + lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz + fi +fi diff --git a/egs/librispeech/ASR/prepare_giga_speech.sh b/egs/librispeech/ASR/prepare_giga_speech.sh index 6f85ddc29..b077aaf3a 100755 --- a/egs/librispeech/ASR/prepare_giga_speech.sh +++ b/egs/librispeech/ASR/prepare_giga_speech.sh @@ -95,39 +95,45 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare GigaSpeech manifest (may take 30 minutes)" # We assume that you have downloaded the GigaSpeech corpus # to $dl_dir/GigaSpeech - mkdir -p data/manifests - lhotse prepare gigaspeech \ - --subset XL \ - --subset L \ - --subset M \ - --subset S \ - --subset XS \ - --subset DEV \ - --subset TEST \ - -j $nj \ - $dl_dir/GigaSpeech data/manifests + if [ ! -f data/manifests/.gigaspeech.done ]; then + mkdir -p data/manifests + lhotse prepare gigaspeech \ + --subset XL \ + --subset L \ + --subset M \ + --subset S \ + --subset XS \ + --subset DEV \ + --subset TEST \ + -j $nj \ + $dl_dir/GigaSpeech data/manifests + touch data/manifests/.gigaspeech.done + fi fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Preprocess GigaSpeech manifest" - if [ ! -f data/fbank/.preprocess_complete ]; then - log "It may take 2 hours for this stage" - python3 ./local/preprocess_gigaspeech.py - touch data/fbank/.preprocess_complete + if [ ! -f data/fbank/.gigaspeech_preprocess.done ]; then + log "It may take 2 hours for this stage" + ./local/preprocess_gigaspeech.py + touch data/fbank/.gigaspeech_preprocess.done fi fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)" - python3 ./local/compute_fbank_gigaspeech_dev_test.py + if [ ! 
-f data/fbank/.gigaspeech_dev_test.done ]; then + ./local/compute_fbank_gigaspeech_dev_test.py + touch data/fbank/.gigaspeech_dev_test.done + fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Split XL subset into ${num_splits} pieces" split_dir=data/fbank/gigaspeech_XL_split_${num_splits} - if [ ! -f $split_dir/.split_completed ]; then + if [ ! -f $split_dir/.gigaspeech_XL_split.done ]; then lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $chunk_size - touch $split_dir/.split_completed + touch $split_dir/.gigaspeech_XL_split.done fi fi @@ -135,8 +141,19 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute features for XL" # Note: The script supports --start and --stop options. # You can use several machines to compute the features in parallel. - python3 ./local/compute_fbank_gigaspeech_splits.py \ - --num-workers $nj \ - --batch-duration 600 \ - --num-splits $num_splits + if [ ! -f data/fbank/.gigaspeech_XL.done ]; then + ./local/compute_fbank_gigaspeech_splits.py \ + --num-workers $nj \ + --batch-duration 600 \ + --num-splits $num_splits + touch data/fbank/.gigaspeech_XL.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Combine features for XL (may take 15 hours)" + if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then + pieces=$(find data/fbank/gigaspeech_XL_split_${num_splits} -name "gigaspeech_cuts_XL.*.jsonl.gz") + lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz + fi fi diff --git a/egs/librispeech/ASR/prepare_multidataset.sh b/egs/librispeech/ASR/prepare_multidataset.sh new file mode 100755 index 000000000..c95b4d039 --- /dev/null +++ b/egs/librispeech/ASR/prepare_multidataset.sh @@ -0,0 +1,330 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=16 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/LibriSpeech +# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. +# You can download them from https://www.openslr.org/12 +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +# Split all dataset to this number of pieces and mix each dataset pieces +# into multidataset pieces with shuffling. +num_splits=1998 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# multidataset list. +# LibriSpeech and musan are required. +# The others are optional. +multidataset=( + "gigaspeech", + "commonvoice", +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. 
+mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +log "Dataset: LibriSpeech and musan" +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download LM" + mkdir -p $dl_dir/lm + if [ ! -e $dl_dir/lm/.done ]; then + ./local/download_lm.py --out-dir=$dl_dir/lm + touch $dl_dir/lm/.done + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriSpeech, + # you can create a symlink + # + # ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech + # + if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then + lhotse download librispeech --full $dl_dir + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriSpeech manifest" + # We assume that you have downloaded the LibriSpeech corpus + # to $dl_dir/LibriSpeech + mkdir -p data/manifests + if [ ! -e data/manifests/.librispeech.done ]; then + lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests + touch data/manifests/.librispeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for librispeech" + mkdir -p data/fbank + if [ ! -e data/fbank/.librispeech.done ]; then + ./local/compute_fbank_librispeech.py --perturb-speed False + touch data/fbank/.librispeech.done + fi + + if [ ! -f data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz + fi + + if [ ! -e data/fbank/.librispeech-validated.done ]; then + log "Validating data/fbank for LibriSpeech" + parts=( + train-clean-100 + train-clean-360 + train-other-500 + test-clean + test-other + dev-clean + dev-other + ) + for part in ${parts[@]}; do + python3 ./local/validate_manifest.py \ + data/fbank/librispeech_cuts_${part}.jsonl.gz + done + touch data/fbank/.librispeech-validated.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/librispeech-lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! 
-f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + files=$( + find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare G" + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! 
-f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Prepare the other datasets" + # GigaSpeech + if [[ "${multidataset[@]}" =~ "gigaspeech" ]]; then + log "Dataset: GigaSpeech" + ./prepare_giga_speech.sh --stop_stage 5 + fi + + # CommonVoice + if [[ "${multidataset[@]}" =~ "commonvoice" ]]; then + log "Dataset: CommonVoice" + ./prepare_common_voice.sh + fi +fi diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index 8dd1459ca..b839a4a4c 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -83,8 +83,7 @@ class LibriSpeechAsrDataModule: "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -208,13 +207,9 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" - ) + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -236,9 +231,7 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -281,9 +274,7 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. 
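# Editorial sketch (not part of the patch): the construction below, written out
# on its own (assuming lhotse is installed), builds a training dataset that
# extracts 80-dim fbank on the fly instead of loading precomputed features.
from lhotse import Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import OnTheFlyFeatures

train = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
    return_cuts=True,
)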
train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -340,9 +331,7 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: @@ -389,23 +378,17 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) + return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 2e9bf3e0b..785a8f097 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,9 +302,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -320,9 +318,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -496,9 +492,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 295a35204..de367c234 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple -from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface +from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - - self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, - knowledge_D) + self.knowledge_base = create_knowledge_base( + knowledge_M, knowledge_N, knowledge_D + ) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. 
- encoder_layer_fn = lambda: ConformerEncoderLayer( + encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K + knowledge_K, ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,9 +187,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -209,10 +207,9 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, - knowledge_D, knowledge_K, - d_model, - knowledge_base) + self.lookup = KnowledgeBaseLookup( + knowledge_M, knowledge_N, knowledge_D, knowledge_K, d_model, knowledge_base + ) self.norm_final = BasicNorm(d_model) @@ -311,9 +308,7 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList( - [encoder_layer_fn() for i in range(num_layers)] - ) + self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) self.num_layers = num_layers def forward( @@ -367,9 +362,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -384,9 +377,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -661,9 +652,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -732,31 +723,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. 
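# Editorial sketch (not part of the patch): the self-attention branch earlier
# in this forward pass projects q, k and v with one linear layer and then
# chunks the result; assuming a weight of shape (3 * E, E), this matches three
# separate E-by-E projections applied to the same input.
import torch
import torch.nn.functional as F

E = 8
x = torch.randn(2, 5, E)
w = torch.randn(3 * E, E)
b = torch.randn(3 * E)

q, k, v = F.linear(x, w, b).chunk(3, dim=-1)

assert torch.allclose(q, F.linear(x, w[:E], b[:E]))
assert torch.allclose(k, F.linear(x, w[E : 2 * E], b[E : 2 * E]))
assert torch.allclose(v, F.linear(x, w[2 * E :], b[2 * E :]))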
# convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -795,9 +777,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -805,13 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -845,13 +821,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -874,9 +846,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index b4a9af55a..82fd103ea 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,11 +76,7 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -186,8 +182,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -245,9 +240,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -262,10 +255,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -385,9 +375,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -398,17 +386,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -418,10 +402,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index b6d94aaf1..0b9c886c7 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,9 +90,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index db51fb1cd..0c9cee431 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
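Not part of the patch: save_results() above loops over one entry per decoding configuration, so the `key` component kept in the file names is what separates per-method outputs. A minimal sketch with assumed example paths and keys:

```python
# Minimal sketch (paths, suffix and keys are assumed examples, not repo settings).
from pathlib import Path

res_dir = Path("pruned2_knowledge/exp")
suffix = "epoch-30-avg-10"
results_dict = {"greedy_search": [], "beam_size_4": []}  # key -> decoding results

for key in results_dict:
    # Without {key}, both iterations would write to the same file.
    print(res_dir / f"recogs-test-clean-{key}-{suffix}.txt")
```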
+from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional from subsampling import ScaledConv1d +from torch import Tensor class Decoder(nn.Module): @@ -90,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -102,7 +101,6 @@ class Decoder(nn.Module): return embedding_out - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -171,8 +169,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -181,34 +184,41 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -217,22 +227,35 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 
2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 96d1a30fb..51020aa30 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -105,8 +105,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -174,9 +173,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 35f75ed2a..68c663b66 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,9 +56,7 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index 599bf2506..ca8c28af1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,9 +63,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -136,9 +134,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 432bf8220..76cd4e11e 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise 
ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -176,18 +166,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -295,10 +281,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 7b05e2f00..5b595c76c 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,32 +3,29 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. 
-import timeit -import torch -from torch import Tensor -from torch import nn -from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd -from typing import Tuple, Optional -from scaling import ScaledLinear import random +import timeit +from typing import Optional, Tuple + +import torch +from scaling import ScaledLinear +from torch import Tensor, nn +from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. - - - - def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M ** N, D)) + a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M**N, D)) nn.init.uniform_(ans, -a, a) return ans + def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -47,9 +44,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup(weights: Tensor, - indexes: Tensor, - knowledge_base: Tensor) -> Tensor: +def weighted_matrix_lookup( + weights: Tensor, indexes: Tensor, knowledge_base: Tensor +) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -65,9 +62,9 @@ def weighted_matrix_lookup(weights: Tensor, # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -76,7 +73,9 @@ def weighted_matrix_lookup(weights: Tensor, class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: + def forward( + ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor + ) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. 
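Not part of the patch: a minimal sketch of the shape contract of weighted_matrix_lookup() shown above; M, N, D, K and the leading batch size are assumed toy values.

```python
# Minimal sketch: weighted sum of K knowledge-base rows per position.
import torch

M, N, D, K = 4, 2, 3, 2
knowledge_base = torch.randn(M**N, D)               # one row per combined index
weights = torch.softmax(torch.randn(5, K), dim=-1)  # (*, K)
indexes = torch.randint(0, M**N, (5, K))            # (*, K)

lookup = knowledge_base[indexes]                    # (*, K, D)
ans = (weights.unsqueeze(-2) @ lookup).squeeze(-2)  # (*, D)
assert ans.shape == (5, D)
```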
@@ -88,15 +87,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward(weights.detach(), indexes.detach(), - knowledge_base.detach()) + ctx.save_for_backward( + weights.detach(), indexes.detach(), knowledge_base.detach() + ) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) #(*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) # (*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad == False + assert weights.requires_grad is False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,16 +115,19 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul(lookup, # (*, K, D) - ans_grad.unsqueeze(-1)) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul( + lookup, ans_grad.unsqueeze(-1) # (*, K, D) + ) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze( + -2 + ) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -146,6 +149,7 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ + @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -154,18 +158,23 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - logprobs, = ctx.saved_tensors + (logprobs,) = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. 
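Not part of the patch: the "negentropy of the average distribution" that PenalizeNegentropyFunction pushes down, computed on a toy batch of log-probs (sizes assumed).

```python
# Minimal sketch: negentropy is sum(p * log p) of the averaged distribution,
# so it is <= 0 and closer to 0 when the average distribution is peaked.
import torch

logprobs = torch.randn(6, 4).log_softmax(dim=-1)  # (rows, M)
flat = logprobs.reshape(-1, logprobs.shape[-1])
avg_dist = flat.exp().mean(dim=0)
negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum()
print(negentropy.item())
```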
- l = logprobs.reshape(-1, logprobs.shape[-1]) + l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) + print( + "Negentropy[individual,combined] = ", + negentropy_individual.item(), + ", ", + negentropy.item(), + ) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -183,18 +192,23 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - def __init__(self, M: int, N: int, D: int, - K: int, embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001): + + def __init__( + self, + M: int, + N: int, + D: int, + K: int, + embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001, + ): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, - initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, - initial_scale = 4.0) + self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) self.M = M self.N = N self.K = K @@ -210,14 +224,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. """ - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -237,38 +251,44 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. 
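Not part of the patch: a shape walk-through of KnowledgeBaseLookup.forward() as commented above. sample_combined() comes from torch_scheduled_sampling and is only described here, not re-implemented; the sizes are assumed toy values.

```python
# Minimal sketch of the projection / reshape / log_softmax steps.
import torch
import torch.nn as nn

M, N, D, K, embedding_dim = 16, 2, 8, 4, 32
in_proj = nn.Linear(embedding_dim, M * N)

x = torch.randn(3, 5, embedding_dim)                    # (*, embedding_dim)
x = in_proj(x)                                          # (*, M*N)
x = x.reshape(*x.shape[:-1], N, M).log_softmax(dim=-1)  # (*, N, M) log-probs
print(x.shape)  # torch.Size([3, 5, 2, 16])
# sample_combined(x, K, input_is_log=True) then yields indexes and weights of
# shape (*, K), which drive the weighted knowledge-base lookup and the out_proj
# back to (*, embedding_dim).
```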
print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + ( + torch.randn(B, T, E, device=device, dtype=dtype), + torch.randn(B, T, E, device=device, dtype=dtype), + ) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) - start = timeit.default_timer() -# Epoch 0, batch 0, loss 1.0109944343566895 -# Epoch 10, batch 0, loss 1.0146660804748535 -# Epoch 20, batch 0, loss 1.0119813680648804 -# Epoch 30, batch 0, loss 1.0105408430099487 -# Epoch 40, batch 0, loss 1.0077732801437378 -# Epoch 50, batch 0, loss 1.0050103664398193 -# Epoch 60, batch 0, loss 1.0033129453659058 -# Epoch 70, batch 0, loss 1.0014232397079468 -# Epoch 80, batch 0, loss 0.9977912306785583 -# Epoch 90, batch 0, loss 0.8274348974227905 -# Epoch 100, batch 0, loss 0.3368612825870514 -# Epoch 110, batch 0, loss 0.11323091387748718 -# Time taken: 17.591704960912466 + # Epoch 0, batch 0, loss 1.0109944343566895 + # Epoch 10, batch 0, loss 1.0146660804748535 + # Epoch 20, batch 0, loss 1.0119813680648804 + # Epoch 30, batch 0, loss 1.0105408430099487 + # Epoch 40, batch 0, loss 1.0077732801437378 + # Epoch 50, batch 0, loss 1.0050103664398193 + # Epoch 60, batch 0, loss 1.0033129453659058 + # Epoch 70, batch 0, loss 1.0014232397079468 + # Epoch 80, batch 0, loss 0.9977912306785583 + # Epoch 90, batch 0, loss 0.8274348974227905 + # Epoch 100, batch 0, loss 0.3368612825870514 + # Epoch 110, batch 0, loss 0.11323091387748718 + # Time taken: 17.591704960912466 for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -276,7 +296,8 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) + def _test_knowledge_base_lookup_autocast(): K = 16 @@ -294,14 +315,18 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. 
print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device('cuda') - train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] + device = torch.device("cuda") + train_pairs = [ + (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) + for _ in range(10) + ] from optim import Eve + optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -309,12 +334,11 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() - for epoch in range(150): - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -323,10 +347,9 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print('Time taken: ', stop - start) + print("Time taken: ", stop - start) - -if __name__ == '__main__': +if __name__ == "__main__": _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index f726c2583..527c735eb 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple -from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,9 +79,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -149,8 +147,7 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -182,11 +179,7 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -202,12 +195,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -218,19 +211,13 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return 
torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -245,12 +232,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -290,11 +277,7 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs + self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -309,12 +292,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -653,9 +636,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -685,8 +666,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 6293e081a..3f21133a0 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,21 +15,23 @@ # limitations under the License. +from typing import Optional, Tuple + import torch import torch.nn as nn from torch import Tensor -from typing import Tuple, Optional - -def _activation_balancer_loss(mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10): +def _activation_balancer_loss( + mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 + eps: float = 1.0e-10, +): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -50,28 +52,32 @@ def _activation_balancer_loss(mean_pos: Tensor, """ loss_parts = [] - x_mean = mean_positive - mean_negative - x_mean_abs = (mean_positive + mean_negative + eps).detach() - x_rel_mean= x_mean / x_mean_abs + x_mean = mean_pos - mean_neg + x_mean_abs = (mean_pos + mean_neg + eps).detach() + x_rel_mean = x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = (-(1-min_positive) + min_positive) - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) + x_rel_mean_floor = -(1 - min_positive) + min_positive + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( + 1.0 / (2 * min_positive) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = - (1.0-max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) + x_rel_mean_ceil = -(1.0 - max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( + 1.0 / (1 - x_rel_mean_ceil) + ) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -82,43 +88,53 @@ def _activation_balancer_loss(mean_pos: Tensor, # 100% violated. loss_parts.append(max_abs_loss) - # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - num + # num if min_positive != 0.0: - - + pass class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 
100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -126,11 +142,16 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -163,29 +184,30 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True) -> None: + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', torch.tensor(eps).log().detach()) - + self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - self.eps.exp()) ** -0.5 + scales = ( + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + ) ** -0.5 return x * scales - - class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -207,27 +229,26 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. 
""" - def __init__(self, *args, - initial_scale: float = 1.0, - **kwargs): + + def __init__(self, *args, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -237,56 +258,67 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, - initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + 
self.groups, + ) class ScaledConv2d(nn.Conv2d): @@ -297,45 +329,58 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), # noqa: F821 + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) - - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -364,12 +409,16 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -379,10 +428,15 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -400,6 +454,7 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. 
""" + @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -411,18 +466,17 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1-s) + s) * y_grad + return (y * (1 - s) + s) * y_grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) - - class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -491,8 +545,13 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -501,33 +560,40 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() - - def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -537,24 +603,37 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def 
extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) @@ -565,8 +644,13 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -576,17 +660,22 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) + def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -621,7 +710,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == '__main__': +if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 2f6840166..123d448bb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,9 +78,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -179,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -203,8 +200,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -554,23 +550,16 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
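Not part of the patch: a worked sketch of the warmup-dependent scale applied in the hunk just below; the warmup values are arbitrary examples.

```python
# Minimal sketch of the schedule: pruned loss off, then damped, then full.
def pruned_loss_scale(warmup: float) -> float:
    return 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)

for w in (0.5, 1.5, 2.5):
    print(w, pruned_loss_scale(w))  # 0.0, 0.1, 1.0
```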
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -733,9 +722,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -835,7 +822,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 2d5724d30..072d49d9c 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -204,8 +204,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -272,9 +271,7 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -289,10 +286,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -415,9 +409,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -428,18 +420,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -449,10 +437,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -494,9 +479,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -528,9 +511,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -557,9 +540,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 318cd5094..008f40fb1 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,13 +272,9 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer( - x, x_lens, states - ) + emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) - if x.size(1) != ( - self.model.segment_length + self.model.right_context_length - ): + if x.size(1) != (self.model.segment_length + self.model.right_context_length): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 2375f5001..3612a2bfd 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -133,8 +133,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -170,9 +169,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -199,9 +198,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -273,9 +272,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index 2f019bcdb..ed6848879 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,9 +122,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index fed814f19..3601e1e11 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,8 +209,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -233,8 +232,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -566,11 +564,7 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. 
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -599,9 +593,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -782,9 +774,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -892,10 +882,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -908,8 +898,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 7af9cc3d7..830b37cfb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,9 +670,7 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -688,9 +686,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. 
""" if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -892,9 +888,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1088,9 +1082,7 @@ def beam_search( t = 0 B = HypothesisList() - B.add( - Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max - ) + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) max_sym_per_utt = 20000 @@ -1130,9 +1122,7 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1) - ) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 7b6338948..3c4500087 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -107,6 +107,7 @@ Usage: import argparse import logging +import math from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -128,11 +129,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -142,6 +139,8 @@ from icefall.utils import ( write_error_stats, ) +LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( @@ -269,8 +268,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -293,7 +291,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). 
Set -1 to use full attention.", ) parser.add_argument( "--left-context", @@ -375,6 +373,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, @@ -383,9 +389,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if ( @@ -450,10 +454,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -584,9 +585,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -597,18 +596,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
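The left-context padding introduced above for simulated streaming can be exercised in isolation like this (a sketch; the shapes and the `left_context` value are assumptions for illustration):

    import math
    import torch
    import torch.nn.functional as F

    LOG_EPS = math.log(1e-10)
    left_context = 64                       # assumed value
    feature = torch.randn(2, 1000, 80)      # (N, T, C) fbank features
    feature_lens = torch.tensor([1000, 800])

    # Append `left_context` frames of log-epsilon at the end of the time axis
    # and grow the lengths to match, mirroring the hunk above.
    feature_lens = feature_lens + left_context
    feature = F.pad(feature, pad=(0, 0, 0, left_context), value=LOG_EPS)
    assert feature.shape == (2, 1000 + left_context, 80)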
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -618,10 +613,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -678,9 +670,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -718,8 +708,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -757,9 +746,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 386248554..e522943c0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,9 +75,7 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = ( - params.right_context + 2 - ) * params.subsampling_factor + 3 + self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -91,13 +89,11 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. 
- self.rnnt_decoding_stream: k2.RnntDecodingStream = ( - k2.RnntDecodingStream(decoding_graph) + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") @property def done(self) -> bool: @@ -126,13 +122,10 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min( - self.num_frames - self.num_processed_frames, chunk_length - ) + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames # noqa - + ret_length + self.num_processed_frames : self.num_processed_frames + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index f4355e8a0..49b82c433 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -58,7 +58,6 @@ class Decoder(nn.Module): self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, - padding_idx=blank_id, ) self.blank_id = blank_id self.unk_id = unk_id @@ -92,9 +91,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py new file mode 100755 index 000000000..a3ebe9d8c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -0,0 +1,527 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. 
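For orientation, the exported encoder could be smoke-tested with onnxruntime roughly as follows (a sketch, not part of this script; the file name follows the example above):

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(
        "encoder-epoch-9999-avg-1.onnx",
        providers=["CPUExecutionProvider"],
    )
    # Metadata written at export time (model_type, comment, ...).
    print(session.get_modelmeta().custom_metadata_map)

    # One dummy utterance: 100 frames of 80-dim fbank features.
    x = np.zeros((1, 100, 80), dtype=np.float32)
    x_lens = np.array([100], dtype=np.int64)
    encoder_out, encoder_out_lens = session.run(
        ["encoder_out", "encoder_out_lens"],
        {"x": x, "x_lens": x_lens},
    )
    print(encoder_out.shape, encoder_out_lens)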
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.utils import setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer""" + + def __init__(self, encoder: Conformer): + """ + Args: + encoder: + A Conformer encoder. + """ + super().__init__() + self.encoder = encoder + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder""" + + def __init__(self, decoder: Decoder): + super().__init__() + self.decoder = decoder + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). 
+ Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + output = decoder_output.squeeze(1) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, inner_linear: nn.Linear, output_linear: nn.Linear): + super().__init__() + self.inner_linear = inner_linear + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.inner_linear(torch.tanh(logit)) + output = self.output_linear(nn.functional.relu(logit)) + + return output + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless3", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.inner_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + encoder = OnnxEncoder(encoder=model.encoder) + + decoder = OnnxDecoder(decoder=model.decoder) + + joiner = OnnxJoiner( + inner_linear=model.joiner.inner_linear, output_linear=model.joiner.output_linear + ) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + 
logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index b5a151878..a19f9ab9a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -105,8 +105,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -192,9 +191,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 73b651b3f..2cca7fa27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,9 +130,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py new file mode 120000 index 000000000..66d63b807 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py new file mode 100755 index 000000000..8134d43f8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +3. 
Run this file + +./pruned_transducer_stateless/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. 
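As a worked example (made-up numbers) of the real-time factor reported further down: it is simply wall-clock decoding time divided by the amount of audio decoded.

    elapsed_seconds = 120.0          # wall-clock time spent decoding
    total_duration = 3600.0          # seconds of audio in the test set
    rtf = elapsed_seconds / total_duration
    print(f"RTF = {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}")  # 0.033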
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index eb95827af..2ed1725b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -168,8 +168,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -221,10 +220,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -292,9 +290,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,9 +377,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index dcf6dc42f..9e09200a1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,14 +166,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index d2cae4f9f..f4b01fd06 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -162,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -269,9 +264,7 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -291,9 +284,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -349,9 +340,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. 
decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -422,9 +411,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -436,9 +423,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -447,9 +432,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -459,10 +442,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -533,8 +513,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 399b11a29..cf4032027 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,8 +203,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -227,8 +226,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -562,9 +560,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." 
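A self-contained sketch of the finite-loss guard shown above, with made-up per-utterance losses; the actual filtering in the recipe may differ in detail:

    import torch

    simple_loss = torch.tensor([12.3, float("inf"), 9.7])
    pruned_loss = torch.tensor([3.1, 2.8, float("nan")])

    simple_loss_is_finite = torch.isfinite(simple_loss)
    pruned_loss_is_finite = torch.isfinite(pruned_loss)

    if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
        # Only triggered when every utterance in the batch is bad.
        raise ValueError(
            "There are too many utterances in this batch leading to inf or nan losses."
        )

    # Otherwise drop just the offending utterances and keep going.
    keep = simple_loss_is_finite & pruned_loss_is_finite
    simple_loss, pruned_loss = simple_loss[keep], pruned_loss[keep]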
@@ -584,9 +580,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -777,9 +771,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -881,10 +873,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -897,8 +889,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -956,9 +947,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b7c2010f7..0280193ca 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -26,7 +26,9 @@ from model import Transducer from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding +from icefall.lm_wrapper import LmScorer from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM from icefall.utils import ( DecodingResults, add_eos, @@ -45,6 +47,8 @@ def fast_beam_search_one_best( max_states: int, max_contexts: int, temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, return_timestamps: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -86,6 +90,8 @@ def fast_beam_search_one_best( max_states=max_states, max_contexts=max_contexts, temperature=temperature, + subtract_ilme=subtract_ilme, + ilme_scale=ilme_scale, ) best_path = one_best_decoding(lattice) @@ -426,6 +432,8 @@ def fast_beam_search( max_states: int, max_contexts: int, temperature: float = 1.0, + subtract_ilme: bool = False, + ilme_scale: float = 0.1, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. 
@@ -496,6 +504,17 @@ def fast_beam_search( ) logits = logits.squeeze(1).squeeze(1) log_probs = (logits / temperature).log_softmax(dim=-1) + if subtract_ilme: + ilme_logits = model.joiner( + torch.zeros_like( + current_encoder_out, device=current_encoder_out.device + ).unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + ilme_logits = ilme_logits.squeeze(1).squeeze(1) + ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) + log_probs -= ilme_scale * ilme_log_probs decoding_streams.advance(log_probs) decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) @@ -580,9 +599,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) + decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( + 1, context_size + ) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,9 +794,7 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) else: self._data[key] = hyp @@ -793,9 +810,7 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -831,11 +846,22 @@ class HypothesisList(object): ans.add(hyp) # shallow copy return ans - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. 
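The internal-LM (ILME) subtraction added in this hunk can be summarised in isolation as below; `joiner` stands for the model's joint network and the call signature is simplified, so treat this as a conceptual sketch only:

    import torch

    def subtract_internal_lm(joiner, encoder_out, decoder_out, ilme_scale=0.1, temperature=1.0):
        # Full posterior from acoustic + label context.
        logits = joiner(encoder_out, decoder_out)
        log_probs = (logits / temperature).log_softmax(dim=-1)

        # Estimate the internal LM by zeroing the acoustic embedding,
        # then subtract a scaled version of that estimate.
        ilme_logits = joiner(torch.zeros_like(encoder_out), decoder_out)
        ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
        return log_probs - ilme_scale * ilme_log_probs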
+ """ hyps = list(self._data.items()) - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] ans = HypothesisList(dict(hyps)) return ans @@ -990,9 +1016,7 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1004,9 +1028,7 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1054,6 +1076,420 @@ def modified_beam_search( ) +def modified_beam_search_lm_rescore( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + # get the best hyp with different lm_scale + for lm_scale in lm_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}" + tot_scores = am_scores.values + lm_scores * lm_scale + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + +def modified_beam_search_lm_rescore_LODR( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + LODR_lm: NgramLm, + sp: spm.SentencePieceProcessor, + lm_scale_list: List[int], + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode 
with --max-sym-per-frame=1 being hardcoded. + Rescore the final results with RNNLM and return the one with the highest score + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + LM: + A neural network language model + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + + # get the am_scores for n-best list + hyps_shape = get_hyps_shape(B) + am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) + am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) + + # now LM rescore + # prepare input data to LM + candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] + possible_seqs = k2.RaggedTensor(candidate_seqs) + row_splits = possible_seqs.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) + possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) + sentence_token_lengths += 1 + + x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) + y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) + x = x.to(device).to(torch.int64) + y = y.to(device).to(torch.int64) + sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) + + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + assert lm_scores.ndim == 2 + lm_scores = -1 * lm_scores.sum(dim=1) + + # now LODR scores + import math + + LODR_scores = [] + for seq in candidate_seqs: + tokens = " ".join(sp.id_to_piece(seq)) + LODR_scores.append(LODR_lm.score(tokens)) + LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( + 10 + ) # arpa scores are 10-based + assert lm_scores.shape == LODR_scores.shape + + ans = {} + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + + LODR_scale_list = [0.05 * i for i in range(1, 20)] + # get the best hyp with different lm_scale and lodr_scale + for lm_scale in lm_scale_list: + for lodr_scale in LODR_scale_list: + key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" + tot_scores = ( + am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale + ) + ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) + max_indexes = ragged_tot_scores.argmax().tolist() + unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] + hyps = [] + 
for idx in unsorted_indices: + hyps.append(unsorted_hyps[idx]) + + ans[key] = hyps + return ans + + def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, @@ -1676,9 +2112,7 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values - + n_scale * ngram_lm_scores - + rnn_scale * rnn_lm_scores + am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1804,9 +2238,7 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax( - dim=-1 - ) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1816,9 +2248,7 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1841,9 +2271,7 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = ( - topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - ) + new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1862,17 +2290,20 @@ def modified_beam_search_ngram_rescoring( return ans -def modified_beam_search_rnnlm_shallow_fusion( +def modified_beam_search_LODR( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, - rnnlm: RnnLmModel, - rnnlm_scale: float, + LODR_lm: NgramLm, + LODR_lm_scale: float, + LM: LmScorer, beam: int = 4, - return_timestamps: bool = False, ) -> List[List[int]]: - """Modified_beam_search + RNNLM shallow fusion + """This function implements LODR (https://arxiv.org/abs/2203.16776) with + `modified_beam_search`. It uses a bi-gram language model as the estimate + of the internal language model and subtracts its score during shallow fusion + with an external language model. This implementation uses a RNNLM as the + external language model. Args: model (Transducer): @@ -1882,24 +2313,24 @@ def modified_beam_search_rnnlm_shallow_fusion( encoder_out_lens (torch.Tensor): A 1-D tensor of shape (N,), containing the number of valid frames in encoder_out before padding. - sp: - Sentence piece generator. - rnnlm (RnnLmModel): - RNNLM - rnnlm_scale (float): - scale of RNNLM in shallow fusion + LODR_lm: + A low order n-gram LM, whose score will be subtracted during shallow fusion + LODR_lm_scale: + The scale of the LODR_lm + LM: + A neural net LM, e.g an RNNLM or transformer LM beam (int, optional): Beam size. Defaults to 4. Returns: Return a list-of-list of token IDs. ans[i] is the decoding results for the i-th utterance. 
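# Illustration only (not part of the patch): the per-token score update that
# modified_beam_search_LODR applies during shallow fusion (see the loop
# further below), with made-up log-probabilities for one candidate token.
am_log_prob = -1.2        # transducer (joiner) log-prob of the new token
nnlm_log_prob = -2.0      # external neural LM log-prob of the new token
bigram_increment = -3.5   # increment of the low-order n-gram score (<= 0)

lm_scale = 0.4            # LM.lm_scale
LODR_lm_scale = -0.16     # passed as a negative number, as noted in the code

new_hyp_log_prob = (
    am_log_prob
    + lm_scale * nnlm_log_prob
    + LODR_lm_scale * bigram_increment
)
# Scaling the (non-positive) bi-gram increment by a negative value subtracts
# the estimate of the transducer's internal LM from the shallow-fusion score.
print(new_hyp_log_prob)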
+ """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size + assert LM is not None + lm_scale = LM.lm_scale packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, @@ -1909,7 +2340,7 @@ def modified_beam_search_rnnlm_shallow_fusion( ) blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") + sos_id = getattr(LM, "sos_id", 1) unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device @@ -1921,7 +2352,8 @@ def modified_beam_search_rnnlm_shallow_fusion( # get initial lm score and lm state by scoring the "sos" token sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) B = [HypothesisList() for _ in range(N)] for i in range(N): @@ -1929,18 +2361,19 @@ def modified_beam_search_rnnlm_shallow_fusion( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, + state=init_states, # state of the NN LM lm_score=init_score.reshape(-1), - timestamp=[], + state_cost=NgramLmStateCost( + LODR_lm + ), # state of the source domain ngram ) ) - rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for batch_size in batch_size_list: start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] # get batch @@ -1995,15 +2428,13 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop looks very similar to the one below. Here, we go through all top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that + LM will score those tokens given the LM states. Note that the variable `scores` is the LM score after seeing the new non-blank token. 
""" @@ -2023,24 +2454,305 @@ def modified_beam_search_rnnlm_shallow_fusion( new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) - # forward RNNLM to get new states and scores + # forward NN LM to get new states and scores if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) - ) + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + # current score of hyp + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + + # calculate the score of the latest token + current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score + + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.state_cost.lm_score, + ) + # score = score + TDLM_score - LODR_score + # LODR_LM_scale should be a negative number here + hyp_log_prob += ( + lm_score[new_token] * lm_scale + + LODR_lm_scale * current_ngram_score + ) # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + else: + state_cost = hyp.state_cost + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + state_cost=state_cost, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +def modified_beam_search_lm_shallow_fusion( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + LM: LmScorer, + beam: int = 4, + return_timestamps: bool 
= False, +) -> List[List[int]]: + """Modified_beam_search + NN LM shallow fusion + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LM (LmScorer): + A neural net LM, e.g RNN or Transformer + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = getattr(LM, "sos_id", 1) + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, + lm_score=init_score.reshape(-1), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + lm_scores = torch.cat( + [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + 
row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + `LM` will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] # a list of list + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) count = 0 # index, used to locate score and lm states for i in range(batch_size): @@ -2067,15 +2779,15 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score + + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) count += 1 new_hyp = Hypothesis( @@ -2103,6 +2815,6 @@ def modified_beam_search_rnnlm_shallow_fusion( return ans else: return DecodingResults( - tokens=ans, + hyps=ans, timestamps=ans_timestamps, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bc273d33b..9bac46004 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. 
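# Illustration only (not part of the patch): the n-best rescoring combination
# used by modified_beam_search_lm_rescore_LODR earlier in this file, with
# made-up scores for a single hypothesis.  ARPA n-gram scores are base-10
# logarithms, hence the math.log(10) conversion before they are combined.
import math

am_score = -12.0           # accumulated transducer log-probability
nnlm_score = -15.0         # neural LM log-probability of the token sequence
lodr_score = -7.0 * math.log(10)   # low-order n-gram score, converted to ln

lm_scale = 0.4
lodr_scale = 0.15          # one point of the LODR_scale_list grid

tot_score = am_score / lm_scale + nnlm_score - lodr_score * lodr_scale
# For every (lm_scale, lodr_scale) pair, the hypothesis with the largest
# tot_score is selected per utterance and stored under a separate result key.
print(tot_score)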
""" - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -378,6 +375,11 @@ class Conformer(EncoderInterface): assert x.size(0) == lengths.max().item() + if chunk_size < 0: + # use full attention + chunk_size = x.size(0) + left_context = -1 + num_left_chunks = -1 if left_context >= 0: assert left_context % chunk_size == 0 @@ -439,9 +441,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +459,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +525,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -785,9 +781,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -811,9 +805,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1127,9 +1119,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1198,31 +1190,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." 
- ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) @@ -1264,23 +1247,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1322,21 +1297,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1355,13 +1326,9 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1498,16 +1465,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : 
(-right_context), # noqa ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 979a0e02e..c57514193 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,11 +132,7 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -275,8 +271,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -300,7 +295,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -383,12 +378,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, @@ -397,9 +394,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -465,10 +460,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -608,9 +600,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -621,18 +611,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -642,10 +628,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -700,9 +683,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -740,8 +721,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -779,9 +759,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index ba91302ce..d44ed6f81 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -59,7 +59,6 @@ class Decoder(nn.Module): self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, - padding_idx=blank_id, ) self.blank_id = blank_id @@ -107,15 +106,11 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( - -1 - ) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index f1a8ea589..984caf5f2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def 
get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -173,8 +168,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -222,9 +216,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 6a9d08033..9f88bd029 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -56,13 +56,9 @@ class Joiner(nn.Module): """ if not is_jit_tracing(): assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 417c391d9..272d06c37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,9 +66,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -152,9 +150,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 041a81f45..2d7f557ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,17 +72,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: 
{}".format(target_rms)) defaults = dict( @@ -118,9 +112,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -147,7 +139,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -158,9 +150,7 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -180,18 +170,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [ - group["initial_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -299,10 +285,9 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index f52cb22ab..013964720 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -168,8 +168,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -222,10 +221,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -293,9 +291,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,9 +378,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 8c572a9ef..963ebdc2d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,9 +89,7 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward( - factor, xgt0, below_threshold, above_threshold - ) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -137,7 +135,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -229,8 +227,7 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) - + self.eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() ) ** -0.5 return x * scales @@ -282,12 +279,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -301,9 +298,7 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear( - input, self.get_weight(), self.get_bias() - ) + return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) class ScaledConv1d(nn.Conv1d): @@ -331,12 +326,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = (3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -400,12 +395,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std + a = 
(3**0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) + scale = fan_in**-0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -476,9 +471,7 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter( - batch_dim=1, threshold=grad_norm_threshold - ) + self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) self._reset_parameters( initial_speed @@ -486,8 +479,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3 ** 0.5) * std - scale = self.hidden_size ** -0.5 + a = (3**0.5) * std + scale = self.hidden_size**-0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -559,15 +552,11 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append( - self._flat_weights[idx] * self._scales[idx].exp() - ) + flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) self._flatten_parameters(flat_weights) return flat_weights - def forward( - self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ): + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -663,16 +652,16 @@ class ActivationBalancer(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: if random.random() >= self.balance_prob: return x - else: - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor / self.balance_prob, - self.min_abs, - self.max_abs, - ) + + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor / self.balance_prob, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -915,9 +904,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -947,8 +934,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1007,11 +994,11 @@ def _test_grad_filter(): print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), + (x_out_grad**2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad ** 2).mean(dim=(0, 2)).sqrt(), + (x.grad**2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index 
9bcd2f9f9..e6e0fb1c8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,9 +153,7 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, decoder_out, project_input=False - ) + logits = model.joiner(current_encoder_out, decoder_out, project_input=False) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -172,14 +170,10 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk( - num_active_paths - ) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d76a03946..9c4a13606 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,11 +51,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -162,8 +158,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -271,9 +266,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -293,9 +286,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -351,9 +342,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -425,9 +414,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -438,9 +425,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -449,9 +434,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -461,10 +444,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -536,8 +516,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1947834bf..6c19f2cb0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,9 +96,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -210,8 +208,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to " - "be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -234,8 +231,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -258,8 +254,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -634,9 +629,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -649,14 +642,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -667,9 +655,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -837,9 +823,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -947,10 +931,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -963,8 +947,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 1df7f9ee5..b7735be85 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,10 +27,7 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import ( - OnTheFlyFeatures, - PrecomputedFeatures, -) +from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -167,9 +164,7 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") @@ -178,9 +173,7 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -250,9 +243,7 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 5784a78ba..b4804ecde 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,11 +79,7 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -192,8 +188,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -280,9 +275,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] if params.decoding_method == "fast_beam_search": @@ -312,10 +305,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -446,9 +436,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -481,8 +469,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -532,9 +519,7 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -567,8 +552,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 8025d6be1..7c62bfa58 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -90,8 +91,41 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 -""" +(8) modified beam search (with LM shallow fusion) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir 
/path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ +""" import argparse import logging @@ -116,15 +150,15 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall import LmScorer, NgramLm +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -206,6 +240,9 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified_beam_search_ngram_rescoring + - modified_beam_search_lm_shallow_fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -265,9 +302,9 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -307,7 +344,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -347,61 +384,58 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. + "--rnn-lm-scale", + type=float, + default=0.0, + help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is rnn-lm. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is rnn-lm. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true """, ) + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. 
+ """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) return parser @@ -415,7 +449,10 @@ def decode_one_batch( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None, - rnn_lm_model: torch.nn.Module = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + rnn_lm_model: Optional[RnnLmModel] = None, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -449,6 +486,13 @@ def decode_one_batch( fast_beam_search_nbest, fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring. It an FsaVec containing an acceptor. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -464,12 +508,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, @@ -478,9 +524,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -550,10 +594,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -617,6 +658,39 @@ def decode_one_batch( nbest_scale=params.nbest_scale, temperature=params.temperature, ) + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + 
beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -691,10 +765,7 @@ def decode_one_batch( return {key: hyps} else: return { - ( - f"beam_size_{params.beam_size}_" - f"temperature_{params.temperature}" - ): hyps + (f"beam_size_{params.beam_size}_temperature_{params.temperature}"): hyps } @@ -706,7 +777,10 @@ def decode_dataset( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None, - rnn_lm_model: torch.nn.Module = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + rnn_lm_model: Optional[RnnLmModel] = None, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -730,6 +804,8 @@ def decode_dataset( fast_beam_search_nbest, fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring. It's an FsaVec containing an acceptor. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -762,7 +838,10 @@ def decode_dataset( decoding_graph=decoding_graph, batch=batch, G=G, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -779,9 +858,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -792,18 +869,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
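# [Editor's note -- illustrative aside, not part of the diff above]
# The save_results() hunks in this file delegate WER computation to
# icefall.utils.write_error_stats.  As a rough, hedged sketch of what a word
# error rate boils down to (NOT the icefall implementation; `simple_wer` is a
# hypothetical helper introduced only for illustration):
def simple_wer(ref_words, hyp_words):
    """WER = (substitutions + insertions + deletions) / len(ref_words)."""
    m, n = len(ref_words), len(hyp_words)
    # dp[i][j] = edit distance between ref_words[:i] and hyp_words[:j]
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # deletion
                dp[i][j - 1] + 1,  # insertion
                dp[i - 1][j - 1] + cost,  # substitution or match
            )
    return dp[m][n] / max(m, 1)
# Example: simple_wer("the cat sat".split(), "the cat sit down".split())
# returns 2/3 (one substitution plus one insertion over 3 reference words).
# write_error_stats() additionally reports per-word counts and aligned
# ref/hyp pairs, which this sketch does not attempt.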
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -813,10 +886,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -899,6 +969,7 @@ def load_ngram_LM( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -916,6 +987,9 @@ def main(): "modified_beam_search", "fast_beam_search_with_nbest_rescoring", "fast_beam_search_with_nbest_rnn_rescoring", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_ngram_rescoring", ) params.res_dir = params.exp_dir / params.decoding_method @@ -939,15 +1013,26 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-temperature-{params.temperature}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -981,8 +1066,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1032,15 +1116,10 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if ( - params.decoding_method - == "fast_beam_search_with_nbest_rnn_rescoring" - ): + if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1065,15 +1144,42 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) rnn_lm_model = None else: decoding_graph = None word_table = None rnn_lm_model = None + # only 
load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -1100,7 +1206,10 @@ def main(): word_table=word_table, decoding_graph=decoding_graph, G=G, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py new file mode 100755 index 000000000..9645b7801 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.utils import setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Conformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). 
+ Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless3", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = 
f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 47217ba05..f30c9df6a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -52,32 +52,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`, It will generates 3 files: `encoder_jit_trace.pt`, `decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. - -(3) Export to ONNX format - -./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Please see ./onnx_pretrained.py for usage of the generated files - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. 
- -(4) Export `model.state_dict()` +(3) Export `model.state_dict()` ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ @@ -128,11 +103,7 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -214,29 +185,11 @@ def get_parser(): """, ) - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit is ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - parser.add_argument( "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -375,210 +328,6 @@ def export_joiner_model_jit_trace( logging.info(f"Saved to {joiner_filename}") -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - # encoder_model = torch.jit.script(encoder_model) - # It throws the following error for the above statement - # - # RuntimeError: Exporting the operator __is_ to ONNX opset version - # 11 is not supported. Please feel free to request support or - # submit a pull request on PyTorch GitHub. - # - # I cannot find which statement causes the above error. - # torch.onnx.export() will use torch.jit.trace() internally, which - # works well for the current reworked model - warmup = 1.0 - torch.onnx.export( - encoder_model, - (x, x_lens, warmup), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens", "warmup"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. 
- opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) - - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "projected_encoder_out", - "projected_decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "projected_encoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -616,8 +365,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for 
--iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -646,31 +394,7 @@ def main(): model.to("cpu") model.eval() - if params.onnx is True: - convert_scaled_to_non_scaled(model, inplace=True) - opset_version = 11 - logging.info("Exporting to onnx format") - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - elif params.jit is True: + if params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.script()") # We won't use the forward() method of the model in C++, so just ignore @@ -715,9 +439,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 36f32c6b3..f3bd6284e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -30,7 +30,7 @@ class GigaSpeech: """ Args: manifest_dir: - It is expected to contain the following files:: + It is expected to contain the following files: - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz - gigaspeech_cuts_L_raw.jsonl.gz @@ -52,18 +52,14 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [ - (int(pattern.search(f).group(1)), f) for f in filenames - ] + idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) + return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 162f8c7db..0669284b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -142,10 +142,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -330,9 +329,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py index 6dba8e9fe..9f2cb6225 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/librispeech.py @@ -72,3 +72,12 @@ class LibriSpeech: f = self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" logging.info(f"About to get dev-other cuts from {f}") return load_manifest_lazy(f) + + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 7852f84e9..d45f6dadc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,9 +84,7 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -190,9 +188,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index d03d1d7ef..5ca4173c1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -19,21 +19,70 @@ """ This script checks that exported onnx models produce the same output with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./pruned_transducer_stateless3/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - cpu_jit.pt + +3. 
Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +4. Run this file + +./pruned_transducer_stateless3/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx """ import argparse import logging from icefall import is_module_available +from onnx_pretrained import OnnxModel -if not is_module_available("onnxruntime"): - raise ValueError("Please 'pip install onnxruntime' first.") - -import onnxruntime as ort import torch -ort.set_default_logger_severity(3) - def get_parser(): parser = argparse.ArgumentParser( @@ -68,174 +117,80 @@ def get_parser(): help="Path to the onnx joiner model", ) - parser.add_argument( - "--onnx-joiner-encoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner encoder projection model", - ) - - parser.add_argument( - "--onnx-joiner-decoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner decoder projection model", - ) - return parser def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - inputs = encoder_session.get_inputs() - outputs = encoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] + C = 80 + for i in range(3): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") - assert inputs[0].shape == ["N", "T", 80] - assert inputs[1].shape == ["N"] + x = torch.rand(N, T, C) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) + x_lens[0] = T - for N in [1, 5]: - for T in [12, 25]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) - encoder_inputs = { - input_names[0]: x.numpy(), - input_names[1]: x_lens.numpy(), - } - encoder_out, encoder_out_lens = encoder_session.run( - output_names, - encoder_inputs, - ) + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max(), - encoder_out.shape, - torch_encoder_out.shape, - ) + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - inputs = decoder_session.get_inputs() - outputs = decoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = 
torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) - assert inputs[0].shape == ["N", 2] - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - output_names, - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - - torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() ) def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, - joiner_encoder_proj_session: ort.InferenceSession, - joiner_decoder_proj_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - joiner_inputs = joiner_session.get_inputs() - joiner_outputs = joiner_session.get_outputs() - joiner_input_names = [n.name for n in joiner_inputs] - joiner_output_names = [n.name for n in joiner_outputs] + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) - assert joiner_inputs[0].shape == ["N", 512] - assert joiner_inputs[1].shape == ["N", 512] + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) - joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() - encoder_proj_input_name = joiner_encoder_proj_inputs[0].name - - assert joiner_encoder_proj_inputs[0].shape == ["N", 512] - - joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() - encoder_proj_output_name = joiner_encoder_proj_outputs[0].name - - joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() - decoder_proj_input_name = joiner_decoder_proj_inputs[0].name - - assert joiner_decoder_proj_inputs[0].shape == ["N", 512] - - joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() - decoder_proj_output_name = joiner_decoder_proj_outputs[0].name - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 512) - decoder_out = torch.rand(N, 512) - - projected_encoder_out = torch.rand(N, 512) - projected_decoder_out = torch.rand(N, 512) - - joiner_inputs = { - joiner_input_names[0]: projected_encoder_out.numpy(), - joiner_input_names[1]: projected_decoder_out.numpy(), - } - joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - projected_encoder_out, - projected_decoder_out, - project_input=False, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, 
projected_decoder_out ) - # Now test encoder_proj - joiner_encoder_proj_inputs = { - encoder_proj_input_name: encoder_out.numpy() - } - joiner_encoder_proj_out = joiner_encoder_proj_session.run( - [encoder_proj_output_name], joiner_encoder_proj_inputs - )[0] - joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) - - torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) - assert torch.allclose( - joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ( - (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) - .abs() - .max() - ) - - # Now test decoder_proj - joiner_decoder_proj_inputs = { - decoder_proj_input_name: decoder_out.numpy() - } - joiner_decoder_proj_out = joiner_decoder_proj_session.run( - [decoder_proj_output_name], joiner_decoder_proj_inputs - )[0] - joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) - - torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) - assert torch.allclose( - joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ( - (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) - .abs() - .max() + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() ) @@ -244,53 +199,41 @@ def main(): args = get_parser().parse_args() logging.info(vars(args)) - model = torch.jit.load(args.jit_filename) + torch_model = torch.jit.load(args.jit_filename) - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) logging.info("Test encoder") - encoder_session = ort.InferenceSession( - args.onnx_encoder_filename, - sess_options=options, - ) - test_encoder(model, encoder_session) + test_encoder(torch_model, onnx_model) logging.info("Test decoder") - decoder_session = ort.InferenceSession( - args.onnx_decoder_filename, - sess_options=options, - ) - test_decoder(model, decoder_session) + test_decoder(torch_model, onnx_model) logging.info("Test joiner") - joiner_session = ort.InferenceSession( - args.onnx_joiner_filename, - sess_options=options, - ) - joiner_encoder_proj_session = ort.InferenceSession( - args.onnx_joiner_encoder_proj_filename, - sess_options=options, - ) - joiner_decoder_proj_session = ort.InferenceSession( - args.onnx_joiner_decoder_proj_filename, - sess_options=options, - ) - test_joiner( - model, - joiner_session, - joiner_encoder_proj_session, - joiner_decoder_proj_session, - ) + test_joiner(torch_model, onnx_model) logging.info("Finished checking ONNX models") +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. 
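# [Editor's note -- illustrative aside, not part of the diff above]
# This rewritten onnx_check.py drives the exported models through the
# OnnxModel wrapper imported from onnx_pretrained.py, which is not shown in
# this diff.  A minimal sketch of how such a run_encoder() call could be
# implemented with onnxruntime is given below; the session options and the
# tensor names mirror the torch.onnx.export() calls earlier in this patch,
# but the wrapper itself is an assumption, not the actual
# onnx_pretrained.OnnxModel code, and the model filename is illustrative.
import onnxruntime as ort
import torch

class SketchOnnxEncoder:
    def __init__(self, encoder_filename: str):
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        self.session = ort.InferenceSession(
            encoder_filename,
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )

    def run_encoder(self, x: torch.Tensor, x_lens: torch.Tensor):
        # Input/output names follow export_encoder_model_onnx() above:
        # inputs "x" (N, T, C) and "x_lens" (N,); outputs "encoder_out"
        # and "encoder_out_lens".
        out = self.session.run(
            ["encoder_out", "encoder_out_lens"],
            {"x": x.numpy(), "x_lens": x_lens.numpy()},
        )
        return torch.from_numpy(out[0]), torch.from_numpy(out[1])

# e.g. SketchOnnxEncoder("encoder-epoch-9999-avg-1.onnx").run_encoder(x, x_lens)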
+torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) if __name__ == "__main__": torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py new file mode 100755 index 000000000..3b1c72cf1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless3/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from librispeech import LibriSpeech + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. 
", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index ea5d4e674..e10915086 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -18,35 +18,60 @@ This script loads ONNX models and uses them to decode waves. You can use the following command to get the exported models: -./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. -Usage of this script: +1. 
Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +3. Run this file ./pruned_transducer_stateless3/onnx_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ - --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ - --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav """ import argparse import logging import math -from typing import List +from typing import List, Tuple +import k2 import kaldifeat -import numpy as np import onnxruntime as ort -import sentencepiece as spm import torch import torchaudio from torch.nn.utils.rnn import pad_sequence @@ -79,23 +104,9 @@ def get_parser(): ) parser.add_argument( - "--joiner-encoder-proj-model-filename", + "--tokens", type=str, - required=True, - help="Path to the joiner encoder_proj onnx model. ", - ) - - parser.add_argument( - "--joiner-decoder-proj-model-filename", - type=str, - required=True, - help="Path to the joiner decoder_proj onnx model. 
", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -115,16 +126,122 @@ def get_parser(): help="The sample rate of the input sound file", ) - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - return parser +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 2-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -140,46 +257,31 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans def greedy_search( - decoder: ort.InferenceSession, - joiner: ort.InferenceSession, - joiner_encoder_proj: ort.InferenceSession, - joiner_decoder_proj: ort.InferenceSession, - encoder_out: np.ndarray, - encoder_out_lens: np.ndarray, - context_size: int, + model: OnnxModel, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: - decoder: - The decoder model. - joiner: - The joiner model. - joiner_encoder_proj: - The joiner encoder projection model. - joiner_decoder_proj: - The joiner decoder projection model. + model: + The transducer model. encoder_out: - A 3-D tensor of shape (N, T, C) + A 3-D tensor of shape (N, T, joiner_dim) encoder_out_lens: A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. Returns: Return the decoded results for each utterance. 
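# Editor's note (illustrative simplification, not part of the diff): for a single
# utterance, the batched packed-sequence implementation that follows reduces to a
# plain loop over encoder frames (max-sym-per-frame is hard-coded to 1):
from typing import List

import torch

def greedy_search_single(model: "OnnxModel", encoder_out: torch.Tensor) -> List[int]:
    """encoder_out: (1, T, joiner_dim), as returned by model.run_encoder()."""
    blank_id = 0  # hard-coded to 0, as in the batched version below
    context_size = model.context_size
    hyp = [blank_id] * context_size
    decoder_out = model.run_decoder(
        torch.tensor([hyp], dtype=torch.int64)
    )  # (1, joiner_dim)
    for t in range(encoder_out.size(1)):
        logits = model.run_joiner(encoder_out[:, t], decoder_out)  # (1, vocab_size)
        token = logits.argmax(dim=-1).item()
        if token != blank_id:
            hyp.append(token)
            decoder_out = model.run_decoder(
                torch.tensor([hyp[-context_size:]], dtype=torch.int64)
            )
    return hyp[context_size:]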
""" - encoder_out = torch.from_numpy(encoder_out) - encoder_out_lens = torch.from_numpy(encoder_out_lens) - assert encoder_out.ndim == 3 + assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( @@ -189,15 +291,6 @@ def greedy_search( enforce_sorted=False, ) - projected_encoder_out = joiner_encoder_proj.run( - [joiner_encoder_proj.get_outputs()[0].name], - { - joiner_encoder_proj.get_inputs()[ - 0 - ].name: packed_encoder_out.data.numpy() - }, - )[0] - blank_id = 0 # hard-code to 0 batch_size_list = packed_encoder_out.batch_sizes.tolist() @@ -206,50 +299,27 @@ def greedy_search( assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) + context_size = model.context_size hyps = [[blank_id] * context_size for _ in range(N)] - decoder_input_nodes = decoder.get_inputs() - decoder_output_nodes = decoder.get_outputs() - - joiner_input_nodes = joiner.get_inputs() - joiner_output_nodes = joiner.get_outputs() - decoder_input = torch.tensor( hyps, dtype=torch.int64, ) # (N, context_size) - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - - projected_decoder_out = torch.from_numpy(projected_decoder_out) + decoder_out = model.run_decoder(decoder_input) offset = 0 for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = projected_encoder_out[start:end] - # current_encoder_out's shape: (batch_size, encoder_out_dim) + current_encoder_out = packed_encoder_out.data[start:end] + # current_encoder_out's shape: (batch_size, joiner_dim) offset = end - projected_decoder_out = projected_decoder_out[:batch_size] + decoder_out = decoder_out[:batch_size] + logits = model.run_joiner(current_encoder_out, decoder_out) - logits = joiner.run( - [joiner_output_nodes[0].name], - { - joiner_input_nodes[0].name: current_encoder_out, - joiner_input_nodes[1].name: projected_decoder_out.numpy(), - }, - )[0] - logits = torch.from_numpy(logits).squeeze(1).squeeze(1) # logits'shape (batch_size, vocab_size) assert logits.ndim == 2, logits.shape @@ -266,17 +336,7 @@ def greedy_search( decoder_input, dtype=torch.int64, ) - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - projected_decoder_out = torch.from_numpy(projected_decoder_out) + decoder_out = model.run_decoder(decoder_input) sorted_ans = [h[context_size:] for h in hyps] ans = [] @@ -292,39 +352,12 @@ def main(): parser = get_parser() args = parser.parse_args() logging.info(vars(args)) - - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - encoder = ort.InferenceSession( - args.encoder_model_filename, - sess_options=session_opts, + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, ) - decoder = ort.InferenceSession( - args.decoder_model_filename, - sess_options=session_opts, 
- ) - - joiner = ort.InferenceSession( - args.joiner_model_filename, - sess_options=session_opts, - ) - - joiner_encoder_proj = ort.InferenceSession( - args.joiner_encoder_proj_model_filename, - sess_options=session_opts, - ) - - joiner_decoder_proj = ort.InferenceSession( - args.joiner_decoder_proj_model_filename, - sess_options=session_opts, - ) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = "cpu" @@ -352,39 +385,33 @@ def main(): ) feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - - encoder_input_nodes = encoder.get_inputs() - encoder_out_nodes = encoder.get_outputs() - encoder_out, encoder_out_lens = encoder.run( - [encoder_out_nodes[0].name, encoder_out_nodes[1].name], - { - encoder_input_nodes[0].name: features.numpy(), - encoder_input_nodes[1].name: feature_lengths.numpy(), - }, - ) + encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) hyps = greedy_search( - decoder=decoder, - joiner=joiner, - joiner_encoder_proj=joiner_encoder_proj, - joiner_decoder_proj=joiner_decoder_proj, + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - context_size=args.context_size, ) s = "\n" + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += symbol_table[i] + return text.replace("▁", " ").strip() + for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" + words = token_ids_to_words(hyp) + s += f"{filename}:\n{words}\n" logging.info(s) logging.info("Decoding Done") if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 19b636a23..7c3dfc660 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -231,10 +230,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -302,9 +300,7 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,9 +387,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 1e7e808c7..a6540c584 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -87,7 +87,7 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: in_features=scaled_linear.in_features, out_features=scaled_linear.out_features, bias=True, # otherwise, it throws errors when converting to PNNX format - # device=weight.device, # Pytorch version before v1.9.0 does not has + # device=weight.device, # Pytorch version before v1.9.0 does not have # this argument. Comment out for now, we will # see if it will raise error for versions # after v1.9.0 @@ -234,9 +234,7 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = ( - scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() - ) + scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm @@ -284,7 +282,7 @@ def convert_scaled_to_non_scaled( if not inplace: model = copy.deepcopy(model) - excluded_patterns = r"self_attn\.(in|out)_proj" + excluded_patterns = r"(self|src)_attn\.(in|out)_proj" p = re.compile(excluded_patterns) d = {} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 10bb44e00..3a1ecb7ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,11 +52,7 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import ( AttributeDict, setup_logger, @@ -163,8 +159,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -272,9 +267,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -294,9 +287,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -352,9 +343,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -426,9 +415,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -439,18 +426,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -460,10 +443,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -535,8 +515,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 66ffbd3ec..598fcf344 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,9 +90,7 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() os.remove(filename) @@ -147,9 +145,7 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -197,9 +193,7 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace( - encoder_layer, (x, pos_emb, src_key_padding_mask) - ) + jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) torch.onnx.export( encoder_layer, @@ -236,9 +230,7 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -322,9 +314,7 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -379,9 +369,7 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( - (onnx_y - torch_y).abs().max() - ) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( 
(onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py index e9dfe6d5e..42de2410a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py @@ -52,17 +52,9 @@ def test_scaled_conv2d(): torch.jit.script(conv2d) -def test_activation_balancer(): - act = ActivationBalancer( - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 - ) - torch.jit.script(act) - - def main(): test_scaled_conv1d() test_scaled_conv2d() - test_activation_balancer() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 44e96644a..fdafa5a87 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,9 +92,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -163,8 +161,7 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", ) parser.add_argument( @@ -214,8 +211,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -238,8 +234,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -262,8 +257,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -672,9 +666,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -687,14 +679,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -705,9 +692,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. 
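# Editor's note (illustrative, not part of the diff): the pruned_loss_scale
# expression above implements a three-stage warm-up before the two losses are
# combined.  Written out as a standalone helper:
def pruned_loss_schedule(warmup: float) -> float:
    """0.0 before warm-up finishes (warmup < 1.0), 0.1 while 1.0 < warmup < 2.0,
    and 1.0 afterwards."""
    return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)

# e.g. pruned_loss_schedule(0.5) == 0.0, pruned_loss_schedule(1.5) == 0.1,
# pruned_loss_schedule(3.0) == 1.0; the total loss is then
# simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss.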
- info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -919,9 +904,7 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -967,8 +950,7 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False @@ -1083,10 +1065,10 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) @@ -1109,9 +1091,7 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest( - Path(args.manifest_dir) / "musan_cuts.jsonl.gz" - ) + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 4f043e5a6..79d919ab1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -109,10 +109,7 @@ Usage: To evaluate symbol delay, you should: (1) Generate cuts with word-time alignments: -./local/add_alignment_librispeech.py \ - --alignments-dir data/alignment \ - --cuts-in-dir data/fbank \ - --cuts-out-dir data/fbank_ali +./add_alignments.sh (2) Set the argument "--manifest-dir data/fbank_ali" while decoding. For example: ./pruned_transducer_stateless4/decode.py \ @@ -306,8 +303,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -330,14 +326,14 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( "--left-context", type=int, default=64, - help="left context can be seen during decoding (in frames after subsampling)", # noqa + help="""Left context can be seen during decoding (in frames after subsampling). 
""", # noqa ) parser.add_argument( @@ -413,12 +409,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, @@ -427,9 +425,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) if ( params.decoding_method == "fast_beam_search" @@ -485,10 +481,7 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -534,7 +527,6 @@ def decode_one_batch( res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( - decoding_method=params.decoding_method, res=res, sp=sp, subsampling_factor=params.subsampling_factor, @@ -566,9 +558,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[ - str, List[Tuple[str, List[str], List[str], List[float], List[float]]] -]: +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -643,9 +633,7 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append( - (cut_id, ref_words, hyp_words, time_ref, time_hyp) - ) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -654,9 +642,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -671,18 +657,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True @@ -693,10 +675,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -704,8 +683,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) @@ -722,9 +700,7 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format( - test_set_name - ) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -773,9 +749,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -812,9 +786,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -841,9 +815,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -902,9 +876,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index ce7518ceb..8f33f5b05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -50,6 +50,7 @@ from pathlib import Path import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint 
import ( @@ -133,8 +134,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -183,9 +183,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -212,9 +212,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -262,6 +262,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not @@ -282,9 +283,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py new file mode 120000 index 000000000..9aa06f82f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 7af9ea9b8..ca3a023ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -175,8 +175,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +303,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +359,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -438,9 +431,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -451,18 +442,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -472,10 +459,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -547,9 +531,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -576,9 +560,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cf32e565b..9bd7df401 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -239,8 +237,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -263,8 +260,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -621,11 +617,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,9 +657,7 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." 
@@ -680,14 +670,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -698,9 +683,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -879,9 +862,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -997,10 +978,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1013,8 +994,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 427b06294..8bbceec61 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -32,7 +32,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask class Conformer(EncoderInterface): @@ -214,10 +214,7 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. 
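# Editor's note (illustrative sketch of the state-caching pattern used just below;
# the cache shapes follow this docstring, device handling is simplified):
import torch

class InitStateCache:
    def __init__(self, num_layers: int, d_model: int, cnn_module_kernel: int):
        self.num_layers = num_layers
        self.d_model = d_model
        self.cnn_module_kernel = cnn_module_kernel
        self._state: list = []

    def get(self, left_context: int, device: torch.device) -> list:
        # Reuse the cached zero states unless left_context changed.
        if len(self._state) == 2 and self._state[0].size(1) == left_context:
            return self._state
        self._state = [
            torch.zeros(self.num_layers, left_context, self.d_model, device=device),
            torch.zeros(
                self.num_layers, self.cnn_module_kernel - 1, self.d_model, device=device
            ),
        ]
        return self._state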
""" - if ( - len(self._init_state) == 2 - and self._init_state[0].size(1) == left_context - ): + if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -439,9 +436,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -459,9 +454,7 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) self.norm_final = BasicNorm(d_model) @@ -527,9 +520,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module( - src, src_key_padding_mask=src_key_padding_mask - ) + conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.dropout(conv) # feed forward module @@ -802,9 +793,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -820,9 +809,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -848,9 +835,7 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward( - self, x: torch.Tensor, left_context: int = 0 - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: """Add positional encoding. 
Args: @@ -1027,15 +1012,28 @@ class RelPositionMultiheadAttention(nn.Module): n == left_context + 2 * time1 - 1 ), f"{n} == {left_context} + 2 * {time1} - 1" # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) + + if is_jit_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) def multi_head_attention_forward( self, @@ -1118,9 +1116,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -1189,31 +1187,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -1253,23 +1242,15 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1310,21 +1291,17 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) else: # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze( - 0 - ) | key_padding_mask.unsqueeze(1).unsqueeze(2) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill( - combined_mask, 0.0 - ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1336,13 +1313,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1481,16 +1454,12 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert ( - not self.training - ), "Cache should be None in training time" + assert not self.training, "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : ( # noqa - -right_context - ), + -(self.lorder + right_context) : (-right_context), # noqa ..., ] else: @@ -1666,9 +1635,7 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor( - (final_weight / (1 - final_weight)) * (self.num_inputs - 1) - ) + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) .log() .item() ) @@ -1765,16 +1732,14 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, 
device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint( - self.num_inputs - 1, (num_frames,), device=device - ) + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot( - indexes, num_classes=self.num_inputs - ).to(dtype=dtype) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 22bcdd88e..7a3e63218 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -87,22 +87,39 @@ Usage: --max-contexts 8 \ --max-states 64 -(8) modified beam search with RNNLM shallow fusion (with LG) +(8) modified beam search with RNNLM shallow fusion ./pruned_transducer_stateless5/decode.py \ --epoch 35 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 4 \ - --max-contexts 4 \ - --rnn-lm-scale 0.4 \ - --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ """ @@ -128,10 +145,13 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, - modified_beam_search_rnnlm_shallow_fusion, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -139,7 +159,6 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -229,7 +248,8 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -271,7 +291,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -303,8 +323,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 
1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -343,69 +362,49 @@ def get_parser(): ) parser.add_argument( - "--rnn-lm-scale", - type=float, - default=0.0, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", + "--use-shallow-fusion", type=str2bool, default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true """, ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -418,8 +417,9 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -448,6 +448,13 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. 
@@ -463,12 +470,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, @@ -477,9 +486,7 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -545,10 +552,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -565,15 +569,36 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -626,8 +651,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - rnnlm: Optional[RnnLmModel] = None, - rnnlm_scale: float = 1.0, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -646,6 +672,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. 
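For modified_beam_search_LODR, the intended per-token score combines the transducer output with the neural LM (positive weight) and a low-order token n-gram (negative weight), which is why the usage example above passes --tokens-ngram 2 together with a negative --ngram-lm-scale. A hedged sketch of that combination, with the weights taken from the example flags:

\[
\log p(y \mid x) \;\approx\; \log p_{\mathrm{rnnt}}(y \mid x)
  + \lambda_{\mathrm{NN}} \log p_{\mathrm{NN}}(y)
  + \lambda_{\mathrm{LODR}} \log p_{\mathrm{2gram}}(y),
\qquad \lambda_{\mathrm{NN}} = 0.4,\ \lambda_{\mathrm{LODR}} = -0.16 .
\]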
@@ -669,7 +697,6 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - logging.info(f"Decoding {batch_idx}-th batch") hyps_dict = decode_one_batch( params=params, @@ -678,8 +705,9 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, - rnnlm=rnnlm, - rnnlm_scale=rnnlm_scale, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -696,9 +724,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -709,18 +735,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -730,10 +752,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -751,6 +770,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -766,7 +786,8 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_rnnlm_shallow_fusion", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -787,14 +808,23 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) if 
params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -828,9 +858,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -857,9 +887,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -906,24 +936,34 @@ def main(): model.to(device) model.eval() - rnn_lm_model = None - rnn_lm_scale = params.rnn_lm_scale - if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": - rnn_lm_model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, ) - assert params.rnn_lm_avg == 1 + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, ) - rnn_lm_model.to(device) - rnn_lm_model.eval() + LM.to(device) + LM.eval() + + else: + LM = None if "fast_beam_search" in params.decoding_method: if "LG" in params.decoding_method: @@ -937,9 +977,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -968,8 +1006,9 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, - rnnlm=rnn_lm_model, - rnnlm_scale=rnn_lm_scale, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py new file mode 100755 index 000000000..e89d94d82 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -0,0 +1,605 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 +as an example to show how to use this file. + +1. 
Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-39-avg-7.pt" + +cd exp +ln -s pretrained-epoch-39-avg-7.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless5/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from onnxruntime.quantization import QuantType, quantize_dynamic +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. 
+ + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Conformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless5", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + 
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index b2e5b430e..54f656859 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -50,6 +50,7 @@ from pathlib import Path import sentencepiece as spm import torch +from 
scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -133,8 +134,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -181,9 +181,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -210,9 +210,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -264,6 +264,7 @@ def main(): # it here. # Otherwise, one of its arguments is a ragged tensor and is not # torch scriptabe. + convert_scaled_to_non_scaled(model, inplace=True) model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) @@ -280,9 +281,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py new file mode 120000 index 000000000..66d63b807 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py new file mode 100755 index 000000000..6f26e34b5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless5/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp/ \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless5/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. 
+ """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." 
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 1e100fcbd..74a2210c3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -166,8 +166,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -198,10 +197,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -264,15 +262,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -344,9 +338,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 6fee9483e..5b15dcee7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -175,8 +175,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -284,9 +283,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -306,9 +303,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -364,9 +359,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_state(params.left_context, device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
decode_stream = DecodeStream( @@ -438,9 +431,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -451,18 +442,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -472,10 +459,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -547,9 +531,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -576,9 +560,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 179d9372e..847c80ab0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,9 +89,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_model_arguments(parser: argparse.ArgumentParser): @@ -248,8 +246,7 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need " - "to be changed.", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( @@ -272,8 +269,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -296,8 +292,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -645,11 +640,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -690,9 +681,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -705,14 +694,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -723,9 +707,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -908,9 +890,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1023,16 +1003,16 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1045,8 +1025,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 53788b3f7..0667e7f61 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,10 +90,7 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert ( - middle_output_layer >= 0 - and middle_output_layer < num_encoder_layers - ) + assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers output_layers.append(middle_output_layer) # The last layer is always needed. @@ -178,9 +175,7 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -362,9 +357,7 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -379,9 +372,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -656,9 +647,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) elif torch.equal(key, value): # encoder-decoder attention @@ -727,31 +718,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -790,9 +772,7 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -800,13 +780,9 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -840,13 +816,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -869,9 +841,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 74df04006..95534efef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -208,8 +208,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -267,9 +266,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out = layer_results[-1] hyps = [] @@ -285,10 +282,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -411,9 +405,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -424,18 +416,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -445,10 +433,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -490,9 +475,7 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -524,9 +507,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -553,9 +536,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index cff9c7377..b6190e8a6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,11 +51,7 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -120,8 +116,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) return parser @@ -160,8 +155,7 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -209,9 +203,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 21409287c..86cf34877 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,9 +21,10 @@ import os from pathlib import Path import torch -from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned +from vq_utils import CodebookIndexExtractor + from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index 49b557814..b8440f90a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,7 +23,6 @@ from pathlib import Path from typing import Dict, List, Tuple import torch - from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -99,9 +98,7 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -124,9 +121,7 @@ def save_results( ) test_set_wers[key] = wer - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -155,9 +150,7 @@ def main(): # reset some parameters needed by hubert. 
params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = ( - params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" - ) + params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -190,9 +183,7 @@ def main(): params=params, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 55ce7b00d..4f9417c9f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,11 +22,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import ( - checkpoint_utils, - tasks, - utils, -) +from fairseq import checkpoint_utils, tasks, utils from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -51,9 +47,7 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / ( - params.teacher_model_id + ".pt" - ) + model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -151,9 +145,7 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( - [-1, 1] - ) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -163,9 +155,7 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask( - features, padding_mask - ) + padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -212,9 +202,7 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [ - self.processor.string(tok[tok != blank].int().cpu()) for tok in toks - ] + hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 7716d19cf..daadb70c9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,9 +69,7 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear( - encoder_dim, vocab_size, initial_speed=0.5 - ) + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -180,9 +178,7 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = 
torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -237,9 +233,7 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes( - middle_layer_output, codebook_indexes - ): + def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index f717d85fb..57753599a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,9 +101,7 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): @@ -203,8 +201,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -227,8 +224,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -569,9 +565,7 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [ - c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts - ] + cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -604,11 +598,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -655,9 +645,7 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all( - ~pruned_loss_is_finite - ): + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -670,14 +658,9 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. 
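 # (Schedule implemented just below: pruned_loss_scale is 0.0 while
 #  warmup < 1.0, 0.1 while 1.0 < warmup < 2.0, and 1.0 otherwise, so the
 #  pruned loss is phased in only after the simple loss has had time to
 #  learn a rough alignment.)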
pruned_loss_scale = ( - 0.0 - if warmup < 1.0 - else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = ( - params.simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -690,9 +673,7 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -873,9 +854,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -991,10 +970,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1007,8 +986,7 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 47cf2b14b..14ff86f23 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -69,7 +69,8 @@ class CodebookIndexExtractor: # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes self.vq_dir = ( - self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + self.params.exp_dir + / f"vq/{self.params.teacher_model_id}_layer{self.params.embedding_layer}_cb{self.params.num_codebooks}/" ) self.vq_dir.mkdir(parents=True, exist_ok=True) @@ -81,7 +82,10 @@ class CodebookIndexExtractor: # It's doesn't matter whether ori_manifest_dir is str or Path. # Set it to Path to be consistent. 
self.ori_manifest_dir = Path("./data/fbank/") - self.dst_manifest_dir = Path("./data/vq_fbank/") + self.dst_manifest_dir = Path( + f"./data/vq_fbank_layer" + + f"{self.params.embedding_layer}_cb{self.params.num_codebooks}/" + ) self.dst_manifest_dir.mkdir(parents=True, exist_ok=True) @@ -208,9 +212,7 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to( - dtype=torch.float - ) + yield data[start:end, :].to(self.params.device).to(dtype=torch.float) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -227,9 +229,7 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") @@ -240,23 +240,46 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) - cuts_vq = cuts_vq.sort_like(cuts_ori) - for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): - assert cut_vq.id == cut_ori.id - cut_ori.codebook_indexes = cut_vq.codebook_indexes + assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!" + + if set(cuts_vq.ids) == set(cuts_ori.ids): + # IDs match exactly + cuts_vq = cuts_vq.sort_like(cuts_ori) + for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): + assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id) + cut_ori.codebook_indexes = cut_vq.codebook_indexes + else: + # in case of ID mismatch, remap them + # get the mapping between audio and cut ID + logging + ori_id_map = {} + for id in cuts_ori.ids: + # some text normalization + if "sp" in id: + clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1] + else: + clean_id = "-".join(id.split("-")[:3]) + ori_id_map[clean_id] = id + + for id in cuts_vq.ids: + if "sp" in id: + clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1] + else: + clean_id = "-".join(id.split("-")[:3]) + assert clean_id in ori_id_map, clean_id + cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[ + id + ].codebook_indexes CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path) logging.info(f"Processed {subset}.") @@ -267,10 +290,12 @@ class CodebookIndexExtractor: Merge generated vq included manfiests and storage to self.dst_manifest_dir. 
""" for subset in self.params.subsets: - vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" + vq_manifests = ( + f"{self.manifest_dir}/" + + f"with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" + ) dst_vq_manifest = ( - self.dst_manifest_dir - / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -330,9 +355,7 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = ( - f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - ) + ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py new file mode 100644 index 000000000..76cd56bbb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -0,0 +1,206 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +import k2 +import torch + +from beam_search import Hypothesis, HypothesisList, get_hyps_shape + +# The force alignment problem can be formulated as finding +# a path in a rectangular lattice, where the path starts +# from the lower left corner and ends at the upper right +# corner. The horizontal axis of the lattice is `t` (representing +# acoustic frame indexes) and the vertical axis is `u` (representing +# BPE tokens of the transcript). +# +# The notations `t` and `u` are from the paper +# https://arxiv.org/pdf/1211.3711.pdf +# +# Beam search is used to find the path with the highest log probabilities. +# +# It assumes the maximum number of symbols that can be +# emitted per frame is 1. + + +def batch_force_alignment( + model: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_list: List[List[int]], + beam_size: int = 4, +) -> List[int]: + """Compute the force alignment of a batch of utterances given their transcripts + in BPE tokens and the corresponding acoustic output from the encoder. + + Caution: + This function is modified from `modified_beam_search` in beam_search.py. + We assume that the maximum number of sybmols per frame is 1. + + Args: + model: + The transducer model. + encoder_out: + A tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + ys_list: + A list of BPE token IDs list. We require that for each utterance i, + len(ys_list[i]) <= encoder_out_lens[i]. + beam_size: + Size of the beam used in beam search. + + Returns: + Return a list of frame indexes list for each utterance i, + where len(ans[i]) == len(ys_list[i]). 
+ """ + assert encoder_out.ndim == 3, encoder_out.ndim + assert encoder_out.size(0) == len(ys_list), (encoder_out.size(0), len(ys_list)) + assert encoder_out.size(0) > 0, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + sorted_indices = packed_encoder_out.sorted_indices.tolist() + encoder_out_lens = encoder_out_lens.tolist() + ys_lens = [len(ys) for ys in ys_list] + sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices] + sorted_ys_lens = [ys_lens[i] for i in sorted_indices] + sorted_ys_list = [ys_list[i] for i in sorted_indices] + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size] + sorted_ys_lens = sorted_ys_lens[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs.reshape(-1) + ) # [batch][num_hyps*vocab_size] + + for i in range(batch_size): + for h, hyp in enumerate(A[i]): + pos_u = len(hyp.timestamp) + idx_offset = h * vocab_size + if (sorted_encoder_out_lens[i] - 1 - t) >= (sorted_ys_lens[i] - pos_u): + # emit blank token + new_hyp = Hypothesis( + log_prob=ragged_log_probs[i][idx_offset + blank_id], + ys=hyp.ys[:], + timestamp=hyp.timestamp[:], + ) + B[i].add(new_hyp) + if pos_u < sorted_ys_lens[i]: + # emit non-blank token + new_token = sorted_ys_list[i][pos_u] + new_hyp = Hypothesis( + log_prob=ragged_log_probs[i][idx_offset + new_token], + ys=hyp.ys + [new_token], + timestamp=hyp.timestamp + [t], + ) + B[i].add(new_hyp) + + if len(B[i]) > beam_size: + B[i] = B[i].topk(beam_size, length_norm=True) + + B = B + finalized_B + sorted_hyps = [b.get_most_probable() for b in B] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + hyps = [sorted_hyps[i] for i in unsorted_indices] + ans = [] + for i, hyp in enumerate(hyps): + assert hyp.ys[context_size:] == ys_list[i], (hyp.ys[context_size:], ys_list[i]) + ans.append(hyp.timestamp) + + return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py new file mode 100755 index 000000000..8bcb56d62 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The script gets forced-alignments based on the modified_beam_search decoding method. +Both token-level alignments and word-level alignments are saved to the new cuts manifests. + +It loads a checkpoint and uses it to get the forced-alignments. 
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +./pruned_transducer_stateless7/compute_ali.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --dataset test-clean \ + --max-duration 300 \ + --beam-size 4 \ + --cuts-out-dir data/fbank_ali_beam_search +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from alignment import batch_force_alignment +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp +from lhotse import CutSet +from lhotse.serialization import SequentialJsonlWriter +from lhotse.supervision import AlignmentItem + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--cuts-out-dir", + type=str, + default="data/fbank_ali_beam_search", + help="The dir to save the new cuts manifests with alignments", + ) + + add_model_arguments(parser) + + return parser + + +def align_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> Tuple[List[List[str]], List[List[str]], List[List[float]], List[List[float]]]: + """Get forced-alignments for one batch. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + token_list: + A list of token list. + word_list: + A list of word list. + token_time_list: + A list of timestamps list for tokens. + word_time_list. + A list of timestamps list for words. 
+ + where len(token_list) == len(word_list) == len(token_time_list) == len(word_time_list), + len(token_list[i]) == len(token_time_list[i]), + and len(word_list[i]) == len(word_time_list[i]) + + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + texts = supervisions["text"] + ys_list: List[List[int]] = sp.encode(texts, out_type=int) + + frame_indexes = batch_force_alignment( + model, encoder_out, encoder_out_lens, ys_list, params.beam_size + ) + + token_list = [] + word_list = [] + token_time_list = [] + word_time_list = [] + for i in range(encoder_out.size(0)): + tokens = sp.id_to_piece(ys_list[i]) + words = texts[i].split() + token_time = convert_timestamp( + frame_indexes[i], params.subsampling_factor, params.frame_shift_ms + ) + word_time = parse_timestamp(tokens, token_time) + assert len(word_time) == len(words), (len(word_time), len(words)) + + token_list.append(tokens) + word_list.append(words) + token_time_list.append(token_time) + word_time_list.append(word_time) + + return token_list, word_list, token_time_list, word_time_list + + +def align_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + writer: SequentialJsonlWriter, +) -> None: + """Get forced-alignments for the dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + writer: + Writer to save the cuts with alignments. + """ + log_interval = 20 + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
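 # (For orientation, the loop below attaches alignments of roughly this shape
 #  to every cut before writing it out; the symbols and times here are only
 #  illustrative:
 #      cut.supervisions[0].alignment == {
 #          "token": [AlignmentItem(symbol="▁HE", start=0.36, duration=None), ...],
 #          "word":  [AlignmentItem(symbol="HE",  start=0.36, duration=None), ...],
 #      }
 #  start is in seconds, rounded to 3 digits; duration is left as None.)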
+ + for batch_idx, batch in enumerate(dl): + token_list, word_list, token_time_list, word_time_list = align_one_batch( + params=params, model=model, sp=sp, batch=batch + ) + + cut_list = batch["supervisions"]["cut"] + for cut, token, word, token_time, word_time in zip( + cut_list, token_list, word_list, token_time_list, word_time_list + ): + assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" + token_ali = [ + AlignmentItem( + symbol=token[i], + start=round(token_time[i], ndigits=3), + duration=None, + ) + for i in range(len(token)) + ] + word_ali = [ + AlignmentItem( + symbol=word[i], start=round(word_time[i], ndigits=3), duration=None + ) + for i in range(len(word)) + ] + cut.supervisions[0].alignment = {"word": word_ali, "token": token_ali} + writer.write(cut, flush=True) + + num_cuts += len(cut_list) + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) + + cuts_out_dir = Path(params.cuts_out_dir) + cuts_out_dir.mkdir(parents=True, exist_ok=True) + cuts_out_path = cuts_out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + + with CutSet.open_writer(cuts_out_path) as writer: + align_dataset(dl=dl, params=params, model=model, sp=sp, writer=writer) + + logging.info( + f"For dataset {params.dataset}, the cut manifest with framewise token alignments " + f"and word alignments are saved to {cuts_out_path}" + ) + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 06c5863f1..55a2493e9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,6 +92,41 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +(8) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(9) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.4 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ + """ @@ -115,9 +151,13 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, + modified_beam_search_ngram_rescoring, 
) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -213,6 +253,8 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -272,9 +314,9 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -302,28 +344,49 @@ def get_parser(): ) parser.add_argument( - "--simulate-streaming", + "--use-shallow-fusion", type=str2bool, default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true """, ) parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], ) parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, ) + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -336,6 +399,9 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -364,6 +430,13 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural net LM for shallow fusion. Only used when `--use-shallow-fusion` + set to true. + ngram_lm: + A ngram lm. Used in LODR decoding. + ngram_lm_scale: + The scale of the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. 
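The added shallow-fusion and LODR decoding methods differ from plain modified_beam_search only in how a candidate token is scored at each step. Below is a minimal sketch of that scoring, assuming the usual shallow-fusion/LODR formulation; the function and variable names are hypothetical, not icefall APIs:

    import math

    def fused_log_prob(
        asr_log_prob: float,          # log p(token) from the transducer joiner
        nnlm_log_prob: float,         # log p(token) from the neural LM (--lm-type, --lm-scale)
        ngram_log_prob: float = 0.0,  # log p(token) from the low-order token n-gram (--tokens-ngram)
        lm_scale: float = 0.3,        # --lm-scale
        ngram_lm_scale: float = 0.0,  # --ngram-lm-scale; negative for LODR, e.g. -0.16
    ) -> float:
        # Shallow fusion adds the scaled neural-LM score; LODR additionally adds a
        # negatively scaled n-gram score to cancel the LM's source-domain bias.
        return asr_log_prob + lm_scale * nnlm_log_prob + ngram_lm_scale * ngram_log_prob

    # modified_beam_search_lm_shallow_fusion: neural LM only
    print(fused_log_prob(math.log(0.6), math.log(0.2), lm_scale=0.3))
    # modified_beam_search_LODR: neural LM plus a subtracted bigram score
    print(fused_log_prob(math.log(0.6), math.log(0.2), math.log(0.3), 0.4, -0.16))

In the real implementation the neural-LM state also has to be carried along with each hypothesis; the sketch above only shows how the per-step log-probabilities are combined.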
@@ -378,24 +451,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -454,10 +510,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -474,6 +527,28 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -523,6 +598,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -541,6 +619,8 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. 
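For reference, the dict returned here and consumed by save_results has the rough shape sketched below, and a word error rate can be recovered from it with a plain edit distance. This is only an illustration with made-up entries, not a substitute for write_error_stats, which the script uses to also report per-word error statistics and aligned ref/hyp pairs:

    # One entry per decoding setting; each value is a list of
    # (cut_id, reference_words, hypothesis_words) triples.
    results_dict = {
        "beam_size_4": [
            ("1089-134686-0001", ["HE", "HOPED"], ["HE", "HOPED"]),
            ("1221-135766-0001", ["STUFF", "IT"], ["STUFF", "IS"]),
        ],
    }

    def edit_distance(ref, hyp):
        # classic dynamic-programming Levenshtein distance over words
        dp = list(range(len(hyp) + 1))
        for i, r in enumerate(ref, 1):
            prev, dp[0] = dp[0], i
            for j, h in enumerate(hyp, 1):
                prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
        return dp[-1]

    for key, results in results_dict.items():
        errors = sum(edit_distance(ref, hyp) for _, ref, hyp in results)
        ref_words = sum(len(ref) for _, ref, _ in results)
        print(f"{key}: WER = {100.0 * errors / ref_words:.2f}%")  # 25.00% here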
@@ -572,6 +652,9 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -588,9 +671,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -601,18 +682,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -622,10 +699,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -643,6 +717,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -657,6 +732,8 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -665,10 +742,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -679,13 +752,24 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + if 
params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -706,11 +790,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model") @@ -718,9 +797,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -747,9 +826,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -796,6 +875,34 @@ def main(): model.to(device) model.eval() + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -808,9 +915,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -839,6 +944,9 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py new file mode 100644 index 000000000..4f64850b6 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py @@ -0,0 +1,861 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn + +# from asr_datamodule import LibriSpeechAsrDataModule +from gigaspeech import GigaSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. 
+ You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
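Note: when --simulate-streaming is true, the features are right-padded along the time axis with --left-context frames of LOG_EPS and the encoder's streaming_forward() is called with --decode-chunk-size and --left-context, so that a streaming model can be evaluated chunk by chunk; otherwise the ordinary full-utterance encoder forward is used.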
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
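(Note: post_processing() above applies the GigaSpeech text normalisation, asr_text_post_processing(), to both the reference and the hypothesis words, so the stored transcripts and the error statistics written below are computed on normalised text.)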
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + """ + This scripts test a libri model with libri BPE + on Gigaspeech. + """ + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / (params.decoding_method + "_gigaspeech") + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif 
len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 712dc8ce1..b085a1817 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -56,7 +56,6 @@ class Decoder(nn.Module): self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, - padding_idx=blank_id, ) self.blank_id = blank_id @@ -69,7 +68,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim//4, # group size == 4 + groups=decoder_dim // 4, # group size == 4 bias=False, ) @@ -87,13 +86,15 @@ class Decoder(nn.Module): y = y.to(torch.int64) # this stuff about clamp() is a temporary fix for a mismatch # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + if torch.jit.is_tracing(): + # This is for exporting to PNNX via ONNX + embedding_out = self.embedding(y) + else: + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py new file mode 100755 index 000000000..2f5d9e338 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" + +cd exp +ln -s pretrained-epoch-30-avg-9.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. 
+ encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless7", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. 
+ + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") 
+ logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + 
model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 5744ea3ea..3e3160e7e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -176,8 +176,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) add_model_arguments(parser) @@ -215,9 +214,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -244,9 +243,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -295,7 +294,6 @@ def main(): if params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.script()") # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not @@ -316,9 +314,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py new file mode 100755 index 000000000..726a24809 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -0,0 +1,1342 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from decoder import Decoder +from gigaspeech import GigaSpeechAsrDataModule +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--do-finetune", type=str2bool, default=False) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". 
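For example, passing "encoder,joiner" initialises every parameter whose name starts with "encoder" or with "joiner".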
+ """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="""Embedding dimension in the 2 blocks of zipformer encoder + layers, comma separated + """, + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers,\ + comma separated; not the same as embedding dimension. + """, + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="""Unmasked dimensions in the encoders, relates to augmentation + during training. Must be <= each of encoder_dims. Empirically, less + than 256 seems to make performance worse. + """, + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="""Path to the BPE model. 
+ This should be the bpe model of the original model + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.005, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
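For example, with average_period=200 and batch_idx_train=1000, the current model is weighted by 200/1000 = 0.2 and the existing average by 800/1000 = 0.8.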
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. 
+ "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
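+    Illustrative output (based on the body below): `exp_dir/epoch-{cur_epoch}.pt`
+    is written on rank 0 at the end of each epoch; if that epoch is also the best
+    training or validation epoch so far, the file is additionally copied to
+    `exp_dir/best-train-loss.pt` or `exp_dir/best-valid-loss.pt`.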
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
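+                # Since reset_interval defaults to 200, tot_loss above is an
+                # exponentially decaying sum: each batch's contribution shrinks
+                # by a factor of 1 - 1/200 = 0.995 per subsequent batch, so it
+                # roughly summarizes the last couple of hundred batches.
+                # The usual GradScaler pattern follows: scale the loss before
+                # backward, step the optimizer through the scaler, then update
+                # the scale and clear the gradients.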
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have + # different behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + else: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + gigaspeech = GigaSpeechAsrDataModule(args) + + train_cuts = gigaspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.dev_cuts() + valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
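+    Example (illustrative): the saved file can be inspected offline with
+    `batch = torch.load("<exp_dir>/batch-<uuid>.pt")`; `batch["inputs"]` then
+    has shape (N, T, 80) for the 80-dim fbank features used in this recipe,
+    and `batch["supervisions"]["text"]` holds the reference transcripts.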
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments( + parser + ) # you may replace this with your own dataset + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py new file mode 100755 index 000000000..37edc0390 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`. 
+You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`. + +(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`. + +(3) use the original model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path +from typing import Dict, List + +import sentencepiece as spm +import torch + +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model." + "If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + print("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + print(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + print(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = ( + params.exp_dir + / f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt" + ) + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = ( + params.exp_dir + / f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt" + ) + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py new file mode 100644 index 000000000..5c01d7190 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py @@ -0,0 +1,406 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class GigaSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", + ) + + # GigaSpeech specific arguments + group.add_argument( + "--subset", + type=str, + default="XL", + help="Select the GigaSpeech subset (XS|S|M|L|XL)", + ) + group.add_argument( + "--small-dev", + type=str2bool, + default=False, + help="Should we use only 1000 utterances for dev (speeds up training)", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." 
+ ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
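+        # Each dataloader worker i is then re-seeded with `seed + i` via
+        # _SeedWorkers above, so random transforms (e.g. MUSAN mixing,
+        # SpecAugment) differ across workers while remaining reproducible
+        # for a fixed base random state.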
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get train_{self.args.subset} cuts") + path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz" + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + if self.args.small_dev: + return cuts_valid.subset(first=1000) + else: + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py new file mode 120000 index 000000000..fdfa6ce4b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py @@ -0,0 +1 @@ +../../../gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index 81b0deba3..5af6dae25 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -30,6 +30,7 @@ Usage of this script: ./pruned_transducer_stateless7/jit_pretrained.py \ --nn-model-filename ./pruned_transducer_stateless7/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ /path/to/foo.wav \ /path/to/bar.wav """ @@ -92,10 +93,9 @@ def read_sound_files( ans = [] for f in 
filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -266,9 +266,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 7d8de5afe..62a4d22d6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -53,12 +53,9 @@ class Joiner(nn.Module): """ assert encoder_out.ndim == decoder_out.ndim assert encoder_out.ndim in (2, 4) - assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 53cde6c6f..0e59b0f2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,14 +15,15 @@ # limitations under the License. +import random + import k2 import torch import torch.nn as nn -import random from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt from icefall.utils import add_sos -from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -65,7 +66,8 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, vocab_size, + encoder_dim, + vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -133,18 +135,16 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (x.size(0), 4), dtype=torch.int64, device=x.device - ) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - #if self.training and random.random() < 0.25: + # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py b/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py new file mode 100644 index 000000000..07c7126fa --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/multidataset.py @@ -0,0 +1,77 @@ +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import glob +import logging +import re +from pathlib import Path + +import lhotse +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, manifest_dir: str, cv_manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files: + + - librispeech_cuts_train-all-shuf.jsonl.gz + - gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz + + cv_manifest_dir: + It is expected to contain the following files: + + - cv-en_cuts_train.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + self.cv_manifest_dir = Path(cv_manifest_dir) + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + # LibriSpeech + logging.info(f"Loading LibriSpeech in lazy mode") + librispeech_cuts = load_manifest_lazy( + self.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + + # GigaSpeech + filenames = glob.glob( + f"{self.manifest_dir}/gigaspeech_XL_split_2000/gigaspeech_cuts_XL.*.jsonl.gz" + ) + + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + + sorted_filenames = [f[1] for f in idx_filenames] + + logging.info(f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode") + + gigaspeech_cuts = lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) + + # CommonVoice + logging.info(f"Loading CommonVoice in lazy mode") + commonvoice_cuts = load_manifest_lazy( + self.cv_manifest_dir / f"cv-en_cuts_train.jsonl.gz" + ) + + return CutSet.mux(librispeech_cuts, gigaspeech_cuts, commonvoice_cuts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py new file mode 120000 index 000000000..20e334271 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless5/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py new file mode 100755 index 000000000..67585ee47 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" + +cd exp +ln -s pretrained-epoch-30-avg-9.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +2. Run this file + +./pruned_transducer_stateless7/onnx_decode.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule + +from onnx_pretrained import greedy_search, OnnxModel + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, sp: spm.SentencePieceProcessor, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. 
+ """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + hyps = [sp.decode(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + sp: + The BPE model. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, sp=sp, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." 
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + blank_id = sp.piece_to_id("") + assert blank_id == 0, blank_id + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset(dl=test_dl, model=model, sp=sp) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bb8b0a0e3..aa3cef338 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -1,4 +1,4 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) # # See ../LICENSE for clarification regarding multiple authors # @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
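The RTF logged by onnx_decode.py above is simply wall-clock decoding time divided by the total audio duration returned by decode_dataset(), so values below 1.0 mean faster than real time. A small worked example with made-up numbers:

    # Hypothetical figures for decoding test-clean (about 5.4 hours of audio) on CPU.
    elapsed_seconds = 1080.0   # wall-clock time spent in decode_dataset()
    total_duration = 19452.0   # total duration of the decoded cuts, in seconds

    rtf = elapsed_seconds / total_duration
    print(f"RTF = {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}")
    # RTF = 1080.000/19452.000 = 0.056, i.e. roughly 18x faster than real time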
-from collections import defaultdict -from typing import List, Optional, Union, Tuple, List -from lhotse.utils import fix_random_seed -import torch -from scaling import ActivationBalancer +import contextlib +import logging import random +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from scaling import ActivationBalancer from torch import Tensor from torch.optim import Optimizer -import logging -import contextlib - class BatchedOptimizer(Optimizer): @@ -37,13 +37,12 @@ class BatchedOptimizer(Optimizer): Args: params: """ + def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager - def batched_params(self, param_group): + def batched_params(self, param_group, group_params_names): """ This function returns (technically, yields) a list of of tuples (p, state), where @@ -65,106 +64,129 @@ class BatchedOptimizer(Optimizer): you can do: with self.batched_params(group["params"]) as batches: - for p, state in batches: + for p, state, p_names in batches: ... Args: group: a parameter group, which is a list of parameters; should be - one of self.groups. + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. """ - batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - for p in param_group: + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): key = (str(p.dtype), *p.shape) batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [ batches[key] for key in sorted(batches.keys()) ] - # pairs will contain pairs of (stacked_param, state), one for each batch - # in `batches`. - pairs = [] + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] - for batch in batches: + for batch, batch_names in zip(batches, batches_names): p = batch[0] # we arbitrarily store the state in the # state corresponding to the 1st parameter in the # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked - pairs.append((p_stacked, state)) + tuples.append((p_stacked, state, batch_names)) - yield pairs # <-- calling code will do the actual optimization here! + yield tuples # <-- calling code will do the actual optimization here! 
- for ((stacked_params, _state), batch) in zip(pairs, batches): + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) - class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. 
By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period """ + def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + parameters_names=None, + show_dominant_parameters=True, ): - + assert parameters_names is not None, ( + "Please prepare parameters_names," + "which is a List[List[str]]. Each List[str] is for a group" + "and each str is for a parameter" + ) defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -179,11 +201,13 @@ class ScaledAdam(BatchedOptimizer): ) super(ScaledAdam, self).__init__(params, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + self.show_dominant_parameters = show_dominant_parameters def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -198,20 +222,23 @@ class ScaledAdam(BatchedOptimizer): loss = closure() batch = True - for group in self.param_groups: - with self.batched_params(group["params"]) as batches: + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. 
- if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) - for p, state in batches: + for p, state, _ in batches: # Perform optimization step. # grad is not going to be None, we handled that when creating the batches. grad = p.grad @@ -225,13 +252,9 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) - return loss - def _init_state(self, - group: dict, - p: Tensor, - state: dict): + def _init_state(self, group: dict, p: Tensor, state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -247,7 +270,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {'device':p.device, 'dtype':p.dtype} + kwargs = {"device": p.device, "dtype": p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -255,48 +278,45 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) batch_size = p.shape[0] numel = p.numel() // batch_size - numel = p.numel() - if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, - **kwargs) - + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale(self, - group: dict, - pairs: List[Tuple[Tensor, dict]]) -> float: + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. Args: group: the parameter group, an item in self.param_groups - pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad - (1st dim is batch dim) and state is the state-dict where optimization parameters are kept. + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". """ - assert len(pairs) >= 1 + assert len(tuples) >= 1 clipping_scale = group["clipping_scale"] - (first_p, first_state) = pairs[0] + (first_p, first_state, _) = tuples[0] step = first_state["step"] if clipping_scale is None or step == 0: # no clipping. 
return early on step == 0 because the other @@ -305,7 +325,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state) in pairs: + for (p, state, param_names) in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -314,57 +334,131 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"])**2).sum() + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() tot_norm = tot_sumsq.sqrt() - if not "model_norms" in first_state: - first_state["model_norms"] = torch.zeros(clipping_update_period, - device=p.device) + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to('cpu') + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") quartiles = [] for n in range(0, 5): - index = min(clipping_update_period - 1, - (clipping_update_period // 4) * n) + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state else 0.0) + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) first_state["num_clipped"] = 0 - quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) - logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except: - logging.info("Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?") + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) return 1.0 - ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) return ans + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter wihch dominanting tot_sumsq. 
- def _step_one_batch(self, - group: dict, - p: Tensor, - state: dict, - clipping_scale: float): + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad**2 + # Dummpy values used by following `zip` statement. + batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter Dominanting tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq = {(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. @@ -391,17 +485,18 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True) + dim=list(range(1, p.ndim)), keepdim=True + ) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), - keepdim=True).sqrt()) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) - if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -411,24 +506,21 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - - def _size_update(self, - group: dict, - scale_grads: Tensor, - p: Tensor, - state: dict) -> None: + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. 
+ Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -443,25 +535,28 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2 ** size_update_period + beta2_corr = beta2**size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` - alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr ** size_step + bias_correction2 = 1 - beta2_corr**size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) - is_too_small = (param_rms < param_min_rms) - is_too_large = (param_rms > param_max_rms) + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -469,13 +564,9 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1-beta1)) + delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, - group: dict, - p: Tensor, - state: dict): + def _step(self, group: dict, p: Tensor, state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. 
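Taken together, these optim.py changes mean ScaledAdam conceptually factors each tensor as p = underlying_param * scale.exp() and learns the log-scale separately (the _size_update above), and it now requires a parameters_names argument so that _show_gradient_dominating_parameter can report which parameter dominates the gradient norm when a large clipping event occurs. A minimal sketch of constructing the optimizer and the Eden scheduler under the new interface, mirroring the pattern used in _plot_eden_lr further below (the toy model and scheduler settings are only illustrative):

    import torch
    from optim import Eden, ScaledAdam  # i.e. the module modified in this hunk

    model = torch.nn.Linear(100, 100)   # stand-in for the real transducer model

    # One List[str] per parameter group; here there is a single group.
    parameters_names = [[name for name, _ in model.named_parameters()]]

    optimizer = ScaledAdam(
        model.parameters(),
        lr=0.03,
        clipping_scale=2.0,
        parameters_names=parameters_names,
    )
    scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)

    # In the training loop one would then call, per batch:
    #     loss.backward(); optimizer.step(); optimizer.zero_grad(); scheduler.step_batch()
    # and scheduler.step_epoch() once per epoch.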
@@ -496,8 +587,7 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=(1-beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -509,17 +599,13 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - - def _step_scalar(self, - group: dict, - p: Tensor, - state: dict): + def _step_scalar(self, group: dict, p: Tensor, state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -531,8 +617,7 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, - value=1-beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. @@ -540,12 +625,11 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr*(1-beta1)) + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) - class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -555,18 +639,14 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [ - group["base_lr"] for group in optimizer.param_groups - ] + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] self.epoch = 0 self.batch = 0 @@ -680,13 +760,15 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) - warmup_factor = (1.0 if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -716,6 +798,47 @@ def _test_eden(): logging.info(f"state dict = {scheduler.state_dict()}") +def _plot_eden_lr(): + import matplotlib.pyplot as plt + + m = torch.nn.Linear(100, 100) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in m.named_parameters()] + ) + + for lr_epoch in [4, 10, 100]: + for lr_batch in [100, 400]: + optim = ScaledAdam( + m.parameters(), lr=0.03, parameters_names=parameters_names + ) + scheduler = Eden( + optim, lr_batches=lr_batch, lr_epochs=lr_epoch, verbose=True 
+ ) + lr = [] + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(500): + lr.append(scheduler.get_lr()) + + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + plt.plot(lr, label=f"lr_epoch:{lr_epoch}, lr_batch:{lr_batch}") + + plt.legend() + plt.savefig("lr.png") + + # This is included mostly as a baseline for ScaledAdam. class Eve(Optimizer): """ @@ -745,13 +868,14 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam\: A Method for Stochastic Optimization: + .. _Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ + def __init__( self, params, @@ -766,17 +890,11 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -812,9 +930,7 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -841,7 +957,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -852,30 +968,31 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). 
- is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg/denom) * step_size - logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") - + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) return loss def _test_scaled_adam(hidden_dim: int): import timeit + from scaling import ScaledLinear + E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - #device = torch.device('cuda') - device = torch.device('cpu') + # device = torch.device('cuda') + device = torch.device("cpu") dtype = torch.float32 fix_random_seed(42) @@ -889,83 +1006,97 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential(Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] - if iter == 0: optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - #if epoch == 100 and iter in [2,3]: + # if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. 
- #if epoch == 130: + # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x,y) in enumerate(train_pairs): + for n, (x, y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y)**2).mean() * 100.0 + loss = ((y_out - y) ** 2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - #diagnostic.print_diagnostics() + # diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - #logging.info("state dict = ", scheduler.state_dict()) - #logging.info("optim state_dict = ", optim.state_dict()) + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") - if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) logging.info(s) import sys + if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: hidden_dim = 200 - _test_scaled_adam(hidden_dim) - _test_eden() + # _test_scaled_adam(hidden_dim) + # _test_eden() + _plot_eden_lr() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 7fe1e681a..d05bafcfb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -177,8 +177,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -209,10 +208,9 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel ans.append(wave[0]) return ans @@ -275,15 +273,11 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) num_waves = encoder_out.size(0) hyps = [] @@ -355,9 +349,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 50cedba56..30a737061 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections +import logging +import random +from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union -from functools import reduce -import logging -import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,27 +32,24 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -65,14 +62,22 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) -def _compute_scale_factor(x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = 
[d for d in range(x.ndim) if d != channel_dim] @@ -83,71 +88,76 @@ def _compute_scale_factor(x: Tensor, else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) return below_threshold - above_threshold -def _compute_sign_factor(x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float) -> Tensor: + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), - dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ((min_positive - proportion_positive) * - (gain_factor / min_positive)).clamp_(min=0, max=max_factor) + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ((proportion_positive - max_positive) * - (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor - class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. 
""" + @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = (x > 0) + xgt0 = x > 0 ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -155,18 +165,24 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return x_grad - neg_delta_grad, None, None, None, + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -179,30 +195,32 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - is_same, = ctx.saved_tensors + (is_same,) = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None -def random_clamp(x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0): + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -215,6 +233,7 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). 
""" + @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -223,35 +242,37 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return random_cast_to_half(ans_grad.to(torch.float32), - min_abs=ctx.min_abs), None + return ( + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), + None, + ) else: return ans_grad, None + class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - def __init__(self, - min_abs: float = 5.0e-06): + + def __init__(self, min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, - x: Tensor): - if torch.jit.is_scripting() or not self.training: + def forward(self, x: Tensor): + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return x else: return RandomGradFunction.apply(x, self.min_abs) - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -267,7 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -276,10 +297,8 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None - -def softmax(x: Tensor, - dim: int): - if torch.jit.is_scripting(): +def softmax(x: Tensor, dim: int): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x.softmax(dim) return SoftmaxFunction.apply(x, dim) @@ -288,20 +307,18 @@ def softmax(x: Tensor, class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x - @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -311,15 +328,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. 
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -385,15 +407,12 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -412,16 +431,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -440,13 +454,10 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -486,18 +497,19 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -515,12 +527,10 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. 
self.cpu_count = 0 - self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) - - + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad: + if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): return _no_op(x) count = self.cpu_count @@ -535,26 +545,35 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor(x, self.channel_dim, - self.min_positive, self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor) + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) else: sign_factor = None - - scale_factor = _compute_scale_factor(x, self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor) + scale_factor = _compute_scale_factor( + x.detach(), + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) return ActivationBalancerFunction.apply( - x, scale_factor, sign_factor, self.channel_dim, + x, + scale_factor, + sign_factor, + self.channel_dim, ) else: return _no_op(x) @@ -594,13 +613,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -630,19 +648,17 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. 
- metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float) -> Tensor: + def forward( + ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float + ) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -650,9 +666,8 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -661,25 +676,28 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") + logging.info( + f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float,float]], - grad_scale: float): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -714,8 +732,7 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -735,19 +752,21 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, 'min_prob') and random.random() < 0.25: + if hasattr(self, "min_prob") and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): # there would be a change to the grad. 
self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply(x, - self.num_groups, - self.whitening_limit, - self.grad_scale) + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) class WithLoss(torch.autograd.Function): @@ -755,20 +774,24 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x + @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device) + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + ) + + def with_loss(x, y): - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x # returns x but adds y.sum() to the loss function. return WithLoss.apply(x, y) def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x else: # a no-op function that will have a node in the autograd graph, @@ -783,6 +806,7 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) + class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -803,13 +827,14 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ + def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -825,7 +850,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer('max_eig_direction', direction) + self.register_buffer("max_eig_direction", direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -833,12 +858,13 @@ class MaxEig(torch.nn.Module): # active. 
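# Illustrative sketch (not part of the patch): the coeffs/direction update that
# MaxEig._find_direction_coeffs performs (see below) acts like one step of power
# iteration on the feature covariance, so repeating it recovers the dominant
# eigen-direction; the synthetic data and iteration count here are arbitrary.
import torch

torch.manual_seed(0)
true_dir = torch.randn(64)
true_dir = true_dir / true_dir.norm()
x = torch.randn(2000, 64) + 4.0 * torch.randn(2000, 1) * true_dir
x = x - x.mean(dim=0)

d = torch.randn(64)                                  # random initial direction
for _ in range(10):
    coeffs = (x * d).sum(dim=1, keepdim=True) + 1.0e-10
    d = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20)

print(torch.abs((d / d.norm()) @ true_dir))          # close to 1.0: dominant direction found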
self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or - self.max_var_per_eig <= 0 or - random.random() > self.cur_prob): + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + or torch.jit.is_tracing() + ): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -848,7 +874,9 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -861,7 +889,9 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") + logging.info( + f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) if variance_proportion >= self.max_var_per_eig: # The constraint is active. Note, we should quite rarely @@ -869,17 +899,16 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, - self.channel_dim, self.scale) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - - def _set_direction(self, - direction: Tensor): + def _set_direction(self, direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -889,40 +918,39 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}") + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) - - def _find_direction_coeffs(self, - x: Tensor, - prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. 
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) return cur_direction, coeffs - - class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -950,7 +978,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -959,7 +987,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -972,12 +1002,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class DoubleSwish(torch.nn.Module): @@ -985,12 +1015,11 @@ class DoubleSwish(torch.nn.Module): """Return double-swish activation function which is an approximation to Swish(Swish(x)), that we approximate closely with x * sigmoid(x-1). 
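# Illustrative sketch (not part of the patch): DoubleSwishFunction above saves
# memory by storing the derivative of x * sigmoid(x - 1) as uint8, since it lies
# in roughly [-0.043637, 1.2]; random dithering keeps the rounding unbiased.
import torch

x = torch.randn(1000)
s = torch.sigmoid(x - 1.0)
deriv = x * s * (1 - s) + s                          # d/dx [x * sigmoid(x - 1)]

floor, ceil = -0.043637, 1.2
d_int = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)).to(torch.uint8)
deriv_hat = d_int.to(torch.float32) * ((ceil - floor) / 255.0) + floor

print((deriv - deriv_hat).abs().max())               # stays within one quantization step (~0.005)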
""" - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x * torch.sigmoid(x - 1.0) return DoubleSwishFunction.apply(x) - def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1002,11 +1031,9 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig(num_channels, - 1, # channel_dim - 0.5, # max_var_per_eig - scale=0.1) # grad_scale - + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale for _ in range(4): y = m(x) @@ -1031,11 +1058,9 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1049,7 +1074,6 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) - def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1077,9 +1101,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1111,8 +1133,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1124,30 +1146,27 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) - # for self-test. x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) - def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 8d357b15f..86067b04f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -22,15 +22,101 @@ BasicNorm is replaced by a module with `exp` removed. 
""" import copy -from typing import List +from typing import List, Tuple import torch import torch.nn as nn -from scaling import ( - ActivationBalancer, - BasicNorm, - Whiten, -) +from scaling import ActivationBalancer, BasicNorm, Whiten +from zipformer import PoolingModule + + +class PoolingModuleNoProj(nn.Module): + def forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x = x.cumsum(dim=0) # (T, N, C) + x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) + # Cumulated numbers of frames from start + cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) + cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + + cached_len = cached_len + x.size(0) + cached_avg = x[-1] + + return x, cached_len, cached_avg + + +class PoolingModuleWithProj(nn.Module): + def __init__(self, proj: torch.nn.Module): + super().__init__() + self.proj = proj + self.pooling = PoolingModuleNoProj() + + def forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg) + return self.proj(x), cached_len, cached_avg + + def streaming_forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg) + return self.proj(x), cached_len, cached_avg class NonScaledNorm(nn.Module): @@ -57,7 +143,7 @@ class NonScaledNorm(nn.Module): def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(BasicNorm) + assert isinstance(basic_norm, BasicNorm), type(basic_norm) norm = NonScaledNorm( num_channels=basic_norm.num_channels, eps_exp=basic_norm.eps.data.exp().item(), @@ -66,6 +152,11 @@ def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: return norm +def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj: + assert isinstance(pooling, PoolingModule), type(pooling) + return PoolingModuleWithProj(proj=pooling.proj) + + # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa # get_submodule was added to nn.Module at v1.9.0 def get_submodule(model, target): @@ -87,6 +178,7 @@ def get_submodule(model, target): def convert_scaled_to_non_scaled( model: nn.Module, inplace: bool = False, + is_pnnx: bool = False, ): """ Args: @@ -95,6 +187,8 @@ def convert_scaled_to_non_scaled( inplace: If True, the input model is modified inplace. If False, the input model is copied and we modify the copied version. 
+ is_pnnx: + True if we are going to export the model for PNNX. Return: Return a model without scaled layers. """ @@ -107,6 +201,8 @@ def convert_scaled_to_non_scaled( d[name] = convert_basic_norm(m) elif isinstance(m, (ActivationBalancer, Whiten)): d[name] = nn.Identity() + elif isinstance(m, PoolingModule) and is_pnnx: + d[name] = convert_pooling_module(m) for k, v in d.items(): if "." in k: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py new file mode 100755 index 000000000..081f7ba1a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script compares the word-level alignments generated based on modified_beam_search decoding +(in ./pruned_transducer_stateless7/compute_ali.py) to the reference alignments generated +by torchaudio framework (in ./add_alignments.sh). + +Usage: + +./pruned_transducer_stateless7/compute_ali.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --dataset test-clean \ + --max-duration 300 \ + --beam-size 4 \ + --cuts-out-dir data/fbank_ali_beam_search + +And the you can run: + +./pruned_transducer_stateless7/test_compute_ali.py \ + --cuts-out-dir ./data/fbank_ali_test \ + --cuts-ref-dir ./data/fbank_ali_torch \ + --dataset train-clean-100 +""" +import argparse +import logging +from pathlib import Path + +import torch +from lhotse import load_manifest + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--cuts-out-dir", + type=Path, + default="./data/fbank_ali", + help="The dir that saves the generated cuts manifests with alignments", + ) + + parser.add_argument( + "--cuts-ref-dir", + type=Path, + default="./data/fbank_ali_torch", + help="The dir that saves the reference cuts manifests with alignments", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset: + Possible values are: + - test-clean + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + + cuts_out_jsonl = args.cuts_out_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz" + cuts_ref_jsonl = args.cuts_ref_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz" + + logging.info(f"Loading {cuts_out_jsonl} and {cuts_ref_jsonl}") + cuts_out = load_manifest(cuts_out_jsonl) + cuts_ref = load_manifest(cuts_ref_jsonl) + cuts_ref = cuts_ref.sort_like(cuts_out) + + all_time_diffs = [] + for cut_out, cut_ref in zip(cuts_out, cuts_ref): + time_out = [ + ali.start + for ali in 
cut_out.supervisions[0].alignment["word"] + if ali.symbol != "" + ] + time_ref = [ + ali.start + for ali in cut_ref.supervisions[0].alignment["word"] + if ali.symbol != "" + ] + assert len(time_out) == len(time_ref), (len(time_out), len(time_ref)) + diff = [ + round(abs(out - ref), ndigits=3) for out, ref in zip(time_out, time_ref) + ] + all_time_diffs += diff + + all_time_diffs = torch.tensor(all_time_diffs) + logging.info( + f"For the word-level alignments abs difference on dataset {args.dataset}, " + f"mean: {'%.2f' % all_time_diffs.mean()}s, std: {'%.2f' % all_time_diffs.std()}s" + ) + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py index db7fb7b3e..cdf914df3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py @@ -20,19 +20,21 @@ To run this file, do: cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless4/test_model.py + python ./pruned_transducer_stateless7/test_model.py """ +import torch + +from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model -def test_model_1(): +def test_model(): params = get_params() params.vocab_size = 500 params.blank_id = 0 params.context_size = 2 params.num_encoder_layers = "2,4,3,2,4" - # params.feedforward_dims = "1024,1024,1536,1536,1024" params.feedforward_dims = "1024,1024,2048,2048,1024" params.nhead = "8,8,8,8,8" params.encoder_dims = "384,384,384,384,384" @@ -47,9 +49,19 @@ def test_model_1(): num_param = sum([p.numel() for p in model.parameters()]) print(f"Number of model parameters: {num_param}") + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + def main(): - test_model_1() + test_model() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py new file mode 100644 index 000000000..2440d267c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file is to test that models can be exported to onnx. 
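# Illustrative sketch (not part of the patch): every test below follows the same
# export-and-compare recipe; this is a minimal version of it on a toy nn.Linear,
# assuming onnxruntime is installed. The file name "toy.onnx" is arbitrary.
import os

import onnxruntime as ort
import torch
import torch.nn as nn

model = nn.Linear(8, 4).eval()
x = torch.randn(3, 8)

torch.onnx.export(
    model,
    x,
    "toy.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "N"}, "y": {0: "N"}},
    opset_version=13,
)

session = ort.InferenceSession("toy.onnx")
(onnx_y,) = session.run(["y"], {"x": x.numpy()})
assert torch.allclose(torch.from_numpy(onnx_y), model(x), atol=1e-5)
os.remove("toy.onnx")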
+""" +import os + +from icefall import is_module_available + +if not is_module_available("onnxruntime"): + raise ValueError("Please 'pip install onnxruntime' first.") + +import onnxruntime as ort +import torch +from scaling_converter import convert_scaled_to_non_scaled +from zipformer import ( + Conv2dSubsampling, + RelPositionalEncoding, + Zipformer, + ZipformerEncoder, + ZipformerEncoderLayer, +) + +ort.set_default_logger_severity(3) + + +def test_conv2d_subsampling(): + filename = "conv2d_subsampling.onnx" + opset_version = 13 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_embed = Conv2dSubsampling(num_features, d_model) + encoder_embed.eval() + encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True) + + torch.onnx.export( + encoder_embed, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "y": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + + onnx_y = session.run(["y"], inputs)[0] + + onnx_y = torch.from_numpy(onnx_y) + torch_y = encoder_embed(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + os.remove(filename) + + +def test_rel_pos(): + filename = "rel_pos.onnx" + + opset_version = 13 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x = x.permute(1, 0, 2) + + torch.onnx.export( + encoder_pos, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["pos_emb"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "pos_emb": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + onnx_pos_emb = session.run(["pos_emb"], inputs) + onnx_pos_emb = torch.from_numpy(onnx_pos_emb[0]) + + torch_pos_emb = encoder_pos(x) + assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( + (onnx_pos_emb - torch_pos_emb).abs().max() + ) + print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum()) + + os.remove(filename) + + +def test_zipformer_encoder_layer(): + filename = "zipformer_encoder_layer.onnx" + opset_version = 13 + N = 30 + T = 50 + + d_model = 384 + attention_dim = 192 + nhead = 8 + feedforward_dim = 1024 + dropout = 0.1 + cnn_module_kernel = 31 + pos_dim = 4 + + x = torch.rand(N, T, d_model) + + encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x = x.permute(1, 0, 2) + pos_emb = encoder_pos(x) + + encoder_layer = ZipformerEncoderLayer( + d_model, + attention_dim, + nhead, + feedforward_dim, + dropout, + cnn_module_kernel, + pos_dim, + ) + encoder_layer.eval() + 
encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) + + torch.onnx.export( + encoder_layer, + (x, pos_emb), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "pos_emb"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "pos_emb": {0: "N", 1: "T"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: pos_emb.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = encoder_layer(x, pos_emb) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_zipformer_encoder(): + filename = "zipformer_encoder.onnx" + + opset_version = 13 + N = 3 + T = 15 + + d_model = 512 + attention_dim = 192 + nhead = 8 + feedforward_dim = 1024 + dropout = 0.1 + cnn_module_kernel = 31 + pos_dim = 4 + num_encoder_layers = 12 + + warmup_batches = 4000.0 + warmup_begin = warmup_batches / (num_encoder_layers + 1) + warmup_end = warmup_batches / (num_encoder_layers + 1) + + x = torch.rand(N, T, d_model) + + encoder_layer = ZipformerEncoderLayer( + d_model, + attention_dim, + nhead, + feedforward_dim, + dropout, + cnn_module_kernel, + pos_dim, + ) + encoder = ZipformerEncoder( + encoder_layer, num_encoder_layers, dropout, warmup_begin, warmup_end + ) + encoder.eval() + encoder = convert_scaled_to_non_scaled(encoder, inplace=True) + + # jit_model = torch.jit.trace(encoder, (pos_emb)) + + torch_y = encoder(x) + + torch.onnx.export( + encoder, + (x), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = encoder(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_zipformer(): + filename = "zipformer.onnx" + opset_version = 11 + N = 3 + T = 15 + num_features = 80 + x = torch.rand(N, T, num_features) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + + zipformer = Zipformer(num_features=num_features) + zipformer.eval() + zipformer = convert_scaled_to_non_scaled(zipformer, inplace=True) + + # jit_model = torch.jit.trace(zipformer, (x, x_lens)) + torch.onnx.export( + zipformer, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["y", "y_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "y": {0: "N", 1: "T"}, + "y_lens": {0: "N"}, + }, + ) + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + 
input_nodes[0].name: x.numpy(), + input_nodes[1].name: x_lens.numpy(), + } + onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs) + onnx_y = torch.from_numpy(onnx_y) + onnx_y_lens = torch.from_numpy(onnx_y_lens) + + torch_y, torch_y_lens = zipformer(x, x_lens) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + + assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( + (onnx_y_lens - torch_y_lens).abs().max() + ) + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + print(onnx_y_lens, torch_y_lens) + + os.remove(filename) + + +@torch.no_grad() +def main(): + test_conv2d_subsampling() + test_rel_pos() + test_zipformer_encoder_layer() + test_zipformer_encoder() + test_zipformer() + + +if __name__ == "__main__": + torch.manual_seed(20221011) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 8927be227..1b179ceff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -59,7 +60,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from zipformer import Zipformer +from multidataset import MultiDataset from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -71,6 +72,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints @@ -79,24 +81,26 @@ from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) -from icefall.hooks import register_inf_check_hooks from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -def set_batch_count( - model: Union[nn.Module, DDP], batch_count: float -) -> None: +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP): # get underlying nn.Module model = model.module for module in model.modules(): - if hasattr(module, 'batch_count'): + if hasattr(module, "batch_count"): module.batch_count = batch_count @@ -126,7 +130,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated" + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -134,7 +138,7 @@ def add_model_arguments(parser: 
argparse.ArgumentParser): type=str, default="192,192,192,192,192", help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""" + not the same as embedding dimension.""", ) parser.add_argument( @@ -143,7 +147,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): default="256,256,256,256,256", help="Unmasked dimensions in the encoders, relates to augmentation during training. " "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse." + " worse.", ) parser.add_argument( @@ -248,10 +252,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", - type=float, - default=0.05, - help="The base learning rate." + "--base-lr", type=float, default=0.05, help="The base learning rate." ) parser.add_argument( @@ -274,8 +275,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -298,8 +298,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -377,6 +376,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-multidataset", + type=str2bool, + default=False, + help="Whether to use multidataset to train.", + ) + add_model_arguments(parser) return parser @@ -429,6 +435,8 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -451,11 +459,14 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Zipformer and Transformer def to_int_tuple(s: str): - return tuple(map(int, s.split(','))) + return tuple(map(int, s.split(","))) + encoder = Zipformer( num_features=params.feature_dim, output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors), + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), encoder_dims=to_int_tuple(params.encoder_dims), attention_dim=to_int_tuple(params.attention_dims), encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), @@ -479,7 +490,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -496,7 +507,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -648,11 +659,18 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
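# Illustrative sketch (not part of the patch): the loss-weight warm-up used in
# compute_loss() below -- the simple-loss weight decays from 1.0 to
# `simple_loss_scale` over `warm_step` batches while the pruned-loss weight grows
# from 0.1 to 1.0. The values s=0.5 and warm_step=2000 are only examples.
def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac

for step in (0, 1000, 2000, 5000):
    print(step, loss_scales(step))
# 0 -> (1.0, 0.1), 1000 -> (0.75, 0.55), 2000 and later -> (0.5, 1.0)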
""" - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -682,27 +700,24 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss = ( - simple_loss_scale * simple_loss + - pruned_loss_scale * pruned_loss - ) + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() @@ -878,7 +893,9 @@ def train_one_epoch( if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}") + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] @@ -888,8 +905,8 @@ def train_one_epoch( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " + - (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -900,16 +917,14 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( @@ -921,7 +936,9 @@ def train_one_epoch( ) model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -997,12 +1014,18 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], - find_unused_parameters=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), - lr=params.base_lr, - clipping_scale=2.0) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1020,7 +1043,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1029,10 +1052,14 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + if params.use_multidataset: + multidataset = MultiDataset(params.manifest_dir, params.cv_manifest_dir) + train_cuts = multidataset.train_cuts() + else: + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1043,7 +1070,30 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to 
get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) @@ -1062,7 +1112,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if not params.use_multidataset and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, @@ -1071,8 +1121,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, - init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1193,7 +1242,9 @@ def scan_pessimistic_batches_for_oom( ) display_and_save_batch(batch, params=params, sp=sp) raise - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) def main(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index c14066d38..5b75b8d35 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -16,32 +16,35 @@ # limitations under the License. import copy -import math -import warnings import itertools -from typing import List, Optional, Tuple, Union import logging -import torch +import math import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) from scaling import ( ActivationBalancer, BasicNorm, - MaxEig, DoubleSwish, - ScaledConv1d, - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. 
- Whiten, Identity, + MaxEig, + ScaledConv1d, + Whiten, _diag, - random_clamp, penalize_abs_values_gt, + random_clamp, softmax, ) from torch import Tensor, nn -from icefall.utils import make_pad_mask from icefall.dist import get_rank +from icefall.utils import is_jit_tracing, make_pad_mask class Zipformer(EncoderInterface): @@ -78,7 +81,6 @@ class Zipformer(EncoderInterface): super(Zipformer, self).__init__() self.num_features = num_features - self.encoder_unmasked_dims = encoder_unmasked_dims assert 0 < encoder_dims[0] <= encoder_dims[1] self.encoder_dims = encoder_dims self.encoder_unmasked_dims = encoder_unmasked_dims @@ -89,7 +91,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u,d in zip(encoder_unmasked_dims, encoder_dims): + for u, d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -97,9 +99,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], - dropout=dropout) - + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -123,13 +125,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -139,10 +141,9 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample(encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor) - + self.downsample_output = AttentionDownsample( + encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor + ) def _get_layer_skip_dropout_prob(self): if not self.training: @@ -166,86 +167,100 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i-1] <= z[i]: + if i <= 1 or z[i - 1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i-2, -1, -1): + for j in range(i - 2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." 
+ ) skip_layers.append(j) - skip_modules.append(SimpleCombiner(self.encoder_dims[j], - self.encoder_dims[i-1], - min_weight=(0.0,0.25))) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks( - self, - x: torch.Tensor) -> List[float]: + def get_feature_masks(self, x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than + On e.g. 15% of frames, these masks will zero out all encoder dims larger than some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. + a smaller encoder dim. We generate the random masks at this level because we want the 2 masks to 'agree' all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. + mask values repeated self.zipformer_downsampling_factors times. Args: x: the embeddings (needed for the shape and dtype and device), of shape (num_frames, batch_size, encoder_dims0) """ num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: - return [ 1.0 ] * num_encoders + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): + return [1.0] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = (num_frames0 + max_downsampling_factor - 1) - + num_frames_max = num_frames0 + max_downsampling_factor - 1 feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, - device=x.device) > - feature_mask_dropout_prob).to(x.dtype) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = (max_downsampling_factor // ds) + upsample_factor = max_downsampling_factor // ds - frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, - batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1)) + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], - dtype=x.dtype, device=x.device) + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, + self, + x: torch.Tensor, + x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ 
-271,18 +286,22 @@ class Zipformer(EncoderInterface): outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module(x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[...,::ds]) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + ) outputs.append(x) x = self.downsample_output(x) @@ -312,15 +331,16 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ + def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -330,29 +350,24 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, attention_dim, nhead, pos_dim, dropout=0.0, + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward2 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - self.feed_forward3 = FeedforwardModule(d_model, - feedforward_dim, - dropout) + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module1 = ConvolutionModule(d_model, - cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, - cnn_module_kernel) + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -360,17 +375,18 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). 
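# Illustrative sketch (not part of the patch): the feature-mask idea from
# get_feature_masks() above -- with probability ~0.15 a frame keeps only its first
# `unmasked_dim` channels (the cross-stack repetition of the same frame mask is
# omitted here). The sizes below are toy values.
import torch

num_frames, batch, dim, unmasked_dim = 8, 2, 6, 4
frame_mask = (torch.rand(num_frames, batch, 1) > 0.15).float()

feature_mask = torch.ones(num_frames, batch, dim)
feature_mask[:, :, unmasked_dim:] *= frame_mask      # zero the channels above unmasked_dim
x = torch.randn(num_frames, batch, dim) * feature_mask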
self.balancer = ActivationBalancer( - d_model, channel_dim=-1, - min_positive=0.45, max_positive=0.55, + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten(num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return self.bypass_scale if random.random() < 0.1: # ensure we get grads if self.bypass_scale becomes out of range @@ -382,15 +398,16 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = (initial_clamp_min - - (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return 0.0 warmup_period = 2000.0 initial_dropout_rate = 0.2 @@ -398,8 +415,9 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return (initial_dropout_rate - - (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) def forward( self, @@ -434,12 +452,12 @@ class ZipformerEncoderLayer(nn.Module): dynamic_dropout = self.get_dynamic_dropout_rate() # pooling module - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) - elif random.random() > dynamic_dropout: + elif random.random() >= dynamic_dropout: src = src + self.pooling(src, key_padding_mask=src_key_padding_mask) - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): src_att, attn_weights = self.self_attn( src, pos_emb=pos_emb, @@ -460,7 +478,7 @@ class ZipformerEncoderLayer(nn.Module): src, src_key_padding_mask=src_key_padding_mask ) else: - use_self_attn = random.random() > dynamic_dropout + use_self_attn = random.random() >= dynamic_dropout if use_self_attn: src_att, attn_weights = self.self_attn( src, @@ -470,7 +488,7 @@ class ZipformerEncoderLayer(nn.Module): ) src = src + src_att - if random.random() > dynamic_dropout: + if random.random() >= dynamic_dropout: src = src + self.conv_module1( src, src_key_padding_mask=src_key_padding_mask ) @@ -479,7 +497,7 @@ class ZipformerEncoderLayer(nn.Module): if use_self_attn: src = src + self.self_attn.forward2(src, attn_weights) - if random.random() > dynamic_dropout: + if random.random() >= dynamic_dropout: src = src + self.conv_module2( src, src_key_padding_mask=src_key_padding_mask ) @@ -508,13 +526,14 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ + def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - 
dropout: float, - warmup_begin: float, - warmup_end: float + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -528,8 +547,7 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, - dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -538,15 +556,13 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1. / num_layers) * (warmup_end - warmup_begin) + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin - def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -579,12 +595,14 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -604,11 +622,12 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) return ans - def forward( self, src: Tensor, @@ -639,8 +658,7 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src - - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): layers_to_drop = [] else: rnd_seed = src.numel() + random.randint(0, 1000) @@ -649,7 +667,7 @@ class ZipformerEncoder(nn.Module): output = output * feature_mask for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if i in layers_to_drop: continue output = mod( @@ -670,28 +688,27 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. 
""" - def __init__(self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int): + + def __init__( + self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int + ): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner(input_dim, - output_dim, - min_weight=(0.0, 0.25)) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) - - def forward(self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. @@ -718,42 +735,43 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds,::ds] + mask = mask[::ds, ::ds] src = self.encoder( - src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + src, + feature_mask=feature_mask, + mask=mask, + src_key_padding_mask=src_key_padding_mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[:src_orig.shape[0]] + src = src[: src_orig.shape[0]] return self.out_combiner(src_orig, src) + class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, - in_channels: int, - out_channels: int, - downsample: int): + + def __init__(self, in_channels: int, out_channels: int, downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear(in_channels * downsample, - out_channels - in_channels, - bias=False) + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) else: self.extra_proj = None self.downsample = downsample - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -763,20 +781,18 @@ class AttentionDownsample(torch.nn.Module): ds = self.downsample d_seq_len = (seq_len + ds - 1) // ds - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + # Pad to an exact multiple of self.downsample, could be 0 for onnx-export-compatibility + # right-pad src, repeating the last element. 
+ pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -795,14 +811,12 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - def __init__(self, - num_channels: int, - upsample: int): + + def __init__(self, num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, - src: Tensor) -> Tensor: + def forward(self, src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -815,6 +829,7 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src + class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -822,6 +837,7 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 + class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -831,18 +847,14 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. """ - def __init__(self, - dim1: int, - dim2: int, - min_weight: Tuple[float] = (0., 0.)): + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, - src1: Tensor, - src2: Tensor) -> Tensor: + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -852,11 +864,15 @@ class SimpleCombiner(torch.nn.Module): assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) weight1 = self.weight1 - if not torch.jit.is_scripting(): - if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): - weight1 = weight1.clamp(min=self.min_weight[0], - max=1.0-self.min_weight[1]) - + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -869,12 +885,9 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] - return src1 + src2 - - class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. 
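# [Editor's note] Illustrative sketch, not part of the patch: a standalone,
# simplified version of the SimpleCombiner.forward() change a little above — mix a
# dim1-dimensional tensor into a dim2-dimensional one (dim2 >= dim1) with a single
# scalar weight, zero-padding the narrower input before adding. The function name
# and the toy shapes are made up for the example.
import torch

def combine(src1: torch.Tensor, src2: torch.Tensor, weight1: float) -> torch.Tensor:
    # src1: (*, dim1), src2: (*, dim2) with dim2 >= dim1; returns (*, dim2)
    src1 = src1 * weight1
    src2 = src2 * (1.0 - weight1)
    extra = src2.shape[-1] - src1.shape[-1]
    if extra > 0:
        # pad the last dimension of src1 with zeros up to dim2
        src1 = torch.nn.functional.pad(src1, (0, extra))
    return src1 + src2

out = combine(torch.randn(20, 4, 256), torch.randn(20, 4, 384), weight1=0.2)
print(out.shape)  # torch.Size([20, 4, 384])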
@@ -888,15 +901,20 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() + if is_jit_tracing(): + # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., + # It assumes that the maximum input won't have more than + # 10k frames. + # + # TODO(fangjun): Use torch.jit.script() for this module + max_len = 10000 self.d_model = d_model self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self.extend_pe(torch.tensor(0.0).expand(max_len)) def extend_pe(self, x: Tensor) -> None: """Reset the positional encodings.""" @@ -905,9 +923,7 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -955,7 +971,6 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) - class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -992,34 +1007,43 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert ( - self.head_dim * num_heads == attention_dim - ), (self.head_dim, num_heads, attention_dim) + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = (2 * attention_dim + # query, key - attention_dim // 2 + # value - pos_dim * num_heads) # positional encoding query + in_proj_dim = ( + 2 * attention_dim # query, key + + attention_dim // 2 # value + + pos_dim * num_heads # positional encoding query + ) - self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, - initial_scale=self.head_dim**-0.25) + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - self.whiten_keys = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, - initial_scale=0.05) + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. 
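# [Editor's note] Worked example, not part of the patch: the in_proj_dim sum in the
# hunk just above, using made-up but plausible values attention_dim = 192,
# num_heads = 8, pos_dim = 4:
#   query             : attention_dim       = 192
#   key               : attention_dim       = 192
#   value             : attention_dim // 2  =  96
#   positional query  : pos_dim * num_heads =  32
#   in_proj_dim                             = 512
# multi_head_attention_forward() below slices x_proj along its last dimension in
# exactly this order (q, k, v, p).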
@@ -1031,14 +1055,16 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, - initial_scale=0.05) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025) - + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) def forward( self, @@ -1098,7 +1124,6 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights - def multi_head_attention_forward( self, x_proj: Tensor, @@ -1158,24 +1183,21 @@ class RelPositionMultiheadAttention(nn.Module): pos_dim = self.pos_dim # positional-encoding dim per head assert ( head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" # self-attention - q = x_proj[...,0:attention_dim] - k = x_proj[...,attention_dim:2*attention_dim] + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] value_dim = attention_dim // 2 - v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[...,2*attention_dim+value_dim:] - + p = x_proj[..., 2 * attention_dim + value_dim :] k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1195,31 +1217,22 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) + raise RuntimeError("The size of the 2D attn_mask is not correct.") elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) + raise RuntimeError("The size of the 3D attn_mask is not correct.") else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: warnings.warn( "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
) @@ -1230,7 +1243,6 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1239,13 +1251,10 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1256,18 +1265,31 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), - (pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2)-pos_weights.stride(3), - pos_weights.stride(3)), - storage_offset=pos_weights.stride(3) * (seq_len - 1)) - + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if training and random.random() < 0.1: # This is a harder way of limiting the attention scores to not be too large. # It incurs a penalty if any of them has an absolute value greater than 50.0. @@ -1275,27 +1297,22 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. 
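# [Editor's note] Illustrative sketch, not part of the patch: what the
# as_strided / torch.gather code in the previous hunk computes. For query position i
# and key position j it reads the entry at relative index (seq_len - 1) + (j - i),
# turning a (..., seq_len, 2*seq_len - 1) tensor indexed by relative offset into a
# (..., seq_len, seq_len) tensor indexed by absolute key position. Toy check:
import torch

seq_len = 3
# last axis enumerates relative offsets -(seq_len - 1) .. +(seq_len - 1)
rel = torch.arange(2 * seq_len - 1, dtype=torch.float32).expand(1, 1, seq_len, -1)
rows = torch.arange(seq_len - 1, -1, -1).unsqueeze(-1)  # [[2], [1], [0]]
cols = torch.arange(seq_len)                            # [0, 1, 2]
idx = (rows + cols).expand(1, 1, seq_len, seq_len)
absolute = torch.gather(rel, dim=-1, index=idx)
print(absolute[0, 0])
# tensor([[2., 3., 4.],
#         [1., 2., 3.],
#         [0., 1., 2.]])  i.e. entry (i, j) == (seq_len - 1) + (j - i)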
- attn_output_weights = penalize_abs_values_gt(attn_output_weights, - limit=25.0, - penalty=1.0e-04) - + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, seq_len, seq_len ) - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - seq_len, - seq_len, - ] - if attn_mask is not None: if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) + attn_output_weights = attn_output_weights.masked_fill( + attn_mask, float("-inf") + ) else: - attn_output_weights += attn_mask + attn_output_weights = attn_output_weights + attn_mask if key_padding_mask is not None: attn_output_weights = attn_output_weights.view( @@ -1315,25 +1332,49 @@ class RelPositionMultiheadAttention(nn.Module): # only storing the half-precision output for backprop purposes. attn_output_weights = softmax(attn_output_weights, dim=-1) + # If we are using chunk-wise attention mask and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`. At this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax). So we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + attn_output_weights = nn.functional.dropout( attn_output_weights, p=dropout_p, training=training ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, - head_dim // 2] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) return attn_output, attn_output_weights - def forward2( self, x: Tensor, @@ -1359,7 +1400,7 @@ class RelPositionMultiheadAttention(nn.Module): # now v: (bsz * num_heads, seq_len, head_dim // 2) attn_output = torch.bmm(attn_weights, v) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if random.random() < 0.001 or __name__ == "__main__": self._print_attn_stats(attn_weights, attn_output) @@ -1372,11 +1413,7 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. 
return self.out_proj2(attn_output) - - def _print_attn_stats( - self, - attn_weights: Tensor, - attn_output: Tensor): + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1387,39 +1424,48 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( - dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) # attn_covar: (num_heads, head_dim, head_dim) - #eigs, _ = torch.symeig(attn_covar) - #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) - out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) - logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") - - + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" + ) class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. 
""" - def __init__(self, - d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, - initial_scale=0.1, bias=False) - def forward(self, - x: Tensor, - key_padding_mask: Optional[Tensor] = None): + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1429,8 +1475,11 @@ class PoolingModule(nn.Module): a Tensor of shape (1, N, C) """ if key_padding_mask is not None: - pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) + if torch.jit.is_tracing(): + pooling_mask = (~key_padding_mask).to(x.dtype) + else: + pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) + pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1444,24 +1493,19 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model. - """ - def __init__(self, - d_model: int, - feedforward_dim: int, - dropout: float): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer(feedforward_dim, - channel_dim=-1, max_abs=10.0, - min_prob=0.25) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, - initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - def forward(self, - x: Tensor): + def forward(self, x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1472,7 +1516,7 @@ class FeedforwardModule(nn.Module): class ConvolutionModule(nn.Module): """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py Args: channels (int): The number of channels of conv layers. @@ -1481,9 +1525,7 @@ class ConvolutionModule(nn.Module): """ - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1513,7 +1555,10 @@ class ConvolutionModule(nn.Module): # the correct range. 
self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, ) self.depthwise_conv = nn.Conv1d( @@ -1527,8 +1572,10 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, channel_dim=1, - min_positive=0.05, max_positive=1.0, + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, max_abs=20.0, ) @@ -1544,9 +1591,10 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward(self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. @@ -1626,8 +1674,7 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, - channel_dim=1), + ActivationBalancer(layer1_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1636,24 +1683,21 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, - channel_dim=1), + ActivationBalancer(layer2_channels, channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, - channel_dim=1), + ActivationBalancer(layer3_channels, channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) - def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1674,6 +1718,7 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x + class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1717,15 +1762,12 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, - num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob - - def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1756,28 +1798,35 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), - size=(num_frames,), device=scores.device).unsqueeze(1) + mask_start = torch.randint( + low=1, + high=int(num_inputs / self.random_prob), + size=(num_frames,), + device=scores.device, + ).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. 
] - arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( - num_frames, num_inputs) + arange = ( + torch.arange(num_inputs, device=scores.device) + .unsqueeze(0) + .expand(num_frames, num_inputs) + ) mask = arange >= mask_start - apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), - device=scores.device) < self.single_prob, - mask_start < num_inputs) - single_prob_mask = torch.logical_and(apply_single_prob, - arange < mask_start - 1) + apply_single_prob = torch.logical_and( + torch.rand(size=(num_frames, 1), device=scores.device) + < self.single_prob, + mask_start < num_inputs, + ) + single_prob_mask = torch.logical_and( + apply_single_prob, arange < mask_start - 1 + ) - mask = torch.logical_or(mask, - single_prob_mask) + mask = torch.logical_or(mask, single_prob_mask) - scores = scores.masked_fill(mask, float('-inf')) + scores = scores.masked_fill(mask, float("-inf")) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1792,7 +1841,6 @@ class AttentionCombine(nn.Module): return ans - def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1801,8 +1849,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0) - + single_prob=0.0, + ) x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1819,7 +1867,10 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), ) batch_size = 5 seq_len = 20 @@ -1828,6 +1879,7 @@ def _test_zipformer_main(): torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) f[0].sum().backward() c.eval() f = c( @@ -1836,19 +1888,18 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings + def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, - dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py new file mode 100755 index 000000000..629bec058 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -0,0 +1,812 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) ctc-decoding +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(2) 1best +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method 1best + +(3) nbest +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method nbest + +(4) nbest-rescoring +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(5) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. 
+ Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. 
+ It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. 
+ """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(feature, feature_lens) + nnet_output = model.ctc_output(encoder_out) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
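# [Editor's note] Descriptive comment, not part of the patch: the `results` dict
# built below maps each decoding setting to a list of per-utterance triples
# (cut_id, ref_words, hyp_words), e.g. (made-up values):
#   {
#     "ctc-decoding": [
#       ("1089-134686-0001", ["HELLO", "WORLD"], ["HELLO", "WORD"]),
#       ...
#     ],
#   }
# For the rescoring methods the keys are LM scales such as "lm_scale_0.7", which is
# why save_results() later reports one WER per scale.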
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + 
device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py new file mode 100755 index 000000000..7641fa5af --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py @@ -0,0 +1,835 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
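For orientation, the checkpoint-selection logic in the decoding script above (the same pattern recurs in every decode/export script added by this patch) can be summarized as follows. This is an illustrative sketch only, not part of the patch; the --epoch/--avg values and the experiment directory are hypothetical, and no model is actually loaded:

    # Sketch only (not part of the patch): which checkpoint files the two
    # averaging modes above would use, for hypothetical values --epoch 30 --avg 9.
    epoch, avg = 30, 9
    exp_dir = "some_recipe/exp"  # hypothetical experiment directory

    # --use-averaged-model 0: average the state_dicts of the last `avg` epochs.
    plain_average = [f"{exp_dir}/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
    # -> epoch-22.pt ... epoch-30.pt (9 files)

    # --use-averaged-model 1: only the two boundary checkpoints are read;
    # average_checkpoints_with_averaged_model() reconstructs the averaged model
    # over the epoch range (epoch - avg, epoch] from them.
    filename_start = f"{exp_dir}/epoch-{epoch - avg}.pt"  # epoch-21.pt (excluded)
    filename_end = f"{exp_dir}/epoch-{epoch}.pt"          # epoch-30.pt

    print(plain_average)
    print(filename_start, filename_end)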
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter 
{params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py new file mode 100755 index 000000000..c1607699f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. 
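The decode.py driver that ends just above writes one wer-summary-<test-set>-<suffix>.txt file per test set via save_results(). As a complement, here is a minimal sketch of reading such a file back; the tab-separated layout (a "settings\tWER" header followed by one row per decoding setting, best first) is taken from the code, while the concrete path below is hypothetical:

    # Sketch only (not part of the patch): parse a wer-summary file produced by
    # save_results() above. The path is hypothetical; see params.suffix above
    # for how the real file name is built.
    import csv
    from pathlib import Path

    def read_wer_summary(path: Path) -> dict:
        with open(path) as f:
            rows = list(csv.reader(f, delimiter="\t"))
        # rows[0] is the "settings\tWER" header; the remaining rows are already
        # sorted by WER, best first.
        return {key: float(wer) for key, wer in rows[1:]}

    summary_file = Path(
        "pruned_transducer_stateless7_ctc/exp/greedy_search"
    ) / "wer-summary-test-clean-epoch-30-avg-9-use-averaged-model.txt"  # hypothetical
    print(read_wer_summary(summary_file))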
You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_ctc/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7_ctc/decode.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py new file mode 100755 index 000000000..280b95984 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_ctc/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = 
fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py new file mode 100755 index 000000000..d50d231d5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
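A note on the batched greedy_search() in jit_pretrained.py above: it walks over packed_encoder_out.batch_sizes so that the decoder state shrinks as shorter utterances finish, and uses unsorted_indices to restore the original utterance order at the end. The sketch below is illustrative only, with hypothetical encoder output lengths, and shows the bookkeeping that pack_padded_sequence provides:

    # Sketch only (not part of the patch): the packing bookkeeping used by
    # greedy_search() above, for hypothetical lengths [5, 3, 2].
    import torch

    encoder_out = torch.zeros(3, 5, 4)          # (N, T, C), zero padded
    encoder_out_lens = torch.tensor([5, 3, 2])
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )
    # How many utterances are still active at each frame; the decoding loop
    # truncates decoder_out to this size at every step.
    print(packed.batch_sizes.tolist())       # [3, 3, 2, 1, 1]
    # Used at the end to restore the original utterance order of the hypotheses.
    print(packed.unsorted_indices.tolist())  # [0, 1, 2]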
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(3) nbest-rescoring +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. 
+ """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py new file mode 100644 index 000000000..a6e919e2f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py @@ -0,0 +1,198 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
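The supervision_segments tensor assembled in jit_pretrained_ctc.py above tells get_lattice() where each utterance lives in the batch: following the list comprehension in the code, each row is (utterance index in the batch, start frame, number of frames after subsampling). An illustrative sketch with hypothetical feature lengths and an assumed subsampling factor of 4:

    # Sketch only (not part of the patch): the supervision layout passed to
    # get_lattice() above. Feature lengths and the subsampling factor are
    # hypothetical here.
    import torch

    feature_lengths = torch.tensor([2000, 1500])  # fbank frames per wave
    subsampling_factor = 4                        # assumed value for illustration
    batch_size = feature_lengths.numel()

    supervision_segments = torch.tensor(
        [
            [i, 0, feature_lengths[i] // subsampling_factor]
            for i in range(batch_size)
        ],
        dtype=torch.int32,
    )
    print(supervision_segments)
    # tensor([[  0,   0, 500],
    #         [  1,   0, 375]], dtype=torch.int32)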
+
+
+from typing import Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+from icefall.utils import add_sos
+
+
+class Transducer(nn.Module):
+    """It implements https://arxiv.org/pdf/1211.3711.pdf
+    "Sequence Transduction with Recurrent Neural Networks"
+    """
+
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        decoder: nn.Module,
+        joiner: nn.Module,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+            `logit_lens` of shape (N,).
+          decoder:
+            It is the prediction network in the paper. Its input shape
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
+          joiner:
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output contains
+            unnormalized probs, i.e., not processed by log-softmax.
+          encoder_dim:
+            Output dimension of the encoder.
+          decoder_dim:
+            Output dimension of the decoder.
+          joiner_dim:
+            Dimension used inside the joiner; encoder and decoder outputs are
+            projected to this dimension.
+          vocab_size:
+            Number of tokens, including the blank.
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+        assert hasattr(decoder, "blank_id")
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joiner = joiner
+
+        self.simple_am_proj = nn.Linear(
+            encoder_dim,
+            vocab_size,
+        )
+        self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
+
+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(encoder_dim, vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: k2.RaggedTensor,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          y:
+            A ragged tensor with 2 axes [utt][label]. It contains labels of each
+            utterance.
+          prune_range:
+            The prune range for rnnt loss, it means how many symbols (context)
+            we are considering for each frame to compute the loss.
+          am_scale:
+            The scale to smooth the loss with am (output of encoder network)
+            part.
+          lm_scale:
+            The scale to smooth the loss with lm (output of predictor network)
+            part.
+        Returns:
+          Return a tuple containing the simple loss, the pruned loss, and the
+          CTC output.
+
+        Note:
+           Regarding am_scale & lm_scale, they make the loss function have the
+           form:
+              lm_scale * lm_probs + am_scale * am_probs +
+              (1-lm_scale-am_scale) * combined_probs
+        """
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+        assert y.num_axes == 2, y.num_axes
+
+        assert x.size(0) == x_lens.size(0) == y.dim0
+
+        encoder_out, x_lens = self.encoder(x, x_lens)
+        assert torch.all(x_lens > 0)
+
+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
+        # Now for the decoder, i.e., the prediction network
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
+        blank_id = self.decoder.blank_id
+        sos_y = add_sos(y, sos_id=blank_id)
+
+        # sos_y_padded: [B, S + 1], start with SOS.
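+        # For example, with blank_id = 0 and y = [[5, 9], [7]], add_sos gives
+        # [[0, 5, 9], [0, 7]], and padding with the blank as filler yields
+        # [[0, 5, 9], [0, 7, 0]], a regular (B, S + 1) tensor.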
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss, ctc_output) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py new file mode 100755 index 000000000..2f1b1a49f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_ctc/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_ctc/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_ctc/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_ctc/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_ctc/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7_ctc/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_ctc/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + 
beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py new file mode 100755 index 000000000..5d460edb5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -0,0 +1,444 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. 
+ We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + params.blank_id = 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = 
k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py new file mode 100755 index 000000000..e482d2040 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_ctc/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model_1(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + # params.feedforward_dims = "1024,1024,1536,1536,1024" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model_1() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py new file mode 100755 index 000000000..718381baa --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -0,0 +1,1287 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
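train.py below configures the Zipformer through comma-separated per-stack strings (e.g. --num-encoder-layers "2,4,3,2,4"); get_encoder_model further down splits them into integer tuples, one entry per encoder stack. A minimal sketch of that parsing, using the recipe's defaults:

def to_int_tuple(s: str):
    # "2,4,3,2,4" -> (2, 4, 3, 2, 4), as done in get_encoder_model() below
    return tuple(map(int, s.split(",")))

num_encoder_layers = to_int_tuple("2,4,3,2,4")
downsampling_factors = to_int_tuple("1,2,4,8,2")
encoder_dims = to_int_tuple("384,384,384,384,384")
# All per-stack options must have the same number of entries (five stacks here).
assert len(num_encoder_layers) == len(downsampling_factors) == len(encoder_dims) == 5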
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_ctc/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_ctc/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for ctc loss + "beam_size": 10, + "use_double_scores": True, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
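    Example:
      With --start-epoch 3 and --start-batch 0, the checkpoint
      exp-dir/epoch-2.pt is loaded and training resumes from epoch 3.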
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
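    Returns:
      Return a tuple of (loss, MetricsTracker). The loss is the weighted sum
      of the simple, pruned and CTC losses over the batch (each computed with
      reduction="sum").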
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + token_ids = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(token_ids).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_output = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + # Compute ctc loss + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + # Works with a BPE model + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
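Because every term above uses reduction="sum", the tracked values are only normalised per frame later (tot_loss["loss"] / tot_loss["frames"], as in compute_validation_loss and at the end of train_one_epoch). For reference, the warm-up weighting applied a few lines above can be reproduced with this small standalone sketch; it mirrors the formulas in compute_loss together with the defaults --simple-loss-scale 0.5 and warm_step 2000 (the function name is illustrative):

def warmup_loss_scales(batch_idx_train: int, s: float = 0.5, warm_step: int = 2000):
    # The simple-loss weight decays linearly from 1.0 to s over warm_step batches,
    # while the pruned-loss weight grows linearly from 0.1 to 1.0.
    simple = s if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    pruned = 1.0 if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step)
    return simple, pruned

for step in (0, 500, 1000, 2000, 5000):
    print(step, warmup_loss_scales(step))  # e.g. 0 -> (1.0, 0.1), 2000 -> (0.5, 1.0)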
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
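            # scaler.scale(loss).backward() backpropagates the scaled loss;
            # scaler.step(optimizer) unscales the gradients and skips the
            # parameter update if they contain inf/NaN; scaler.update() then
            # adjusts the scale factor for the next iteration. When
            # --use-fp16 is false the scaler is created with enabled=False,
            # so these calls reduce to a plain backward() and step().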
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py 
b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py new file mode 100755 index 000000000..fa7144f0f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -0,0 +1,803 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding +(2) 1best +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method 1best +(3) nbest +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method nbest +(4) nbest-rescoring +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring +(5) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. 
+ It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. 
+ """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(feature, feature_lens) + nnet_output = model.ctc_output(encoder_out) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
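The rescoring branches in decode_one_batch above sweep lm_scale from 0.1 to 2.0 and return one best path per scale, keyed as lm_scale_<value> (see the docstring earlier in this file). When decoding time matters, the sweep can be narrowed; a hedged sketch, assuming `lattice` and `G` as constructed above:

# Hypothetical: restrict the sweep to a few scales around a known-good value.
lm_scale_list = [0.5, 0.6, 0.7, 0.8]
best_path_dict = rescore_with_whole_lattice(
    lattice=lattice,
    G_with_epsilon_loops=G,
    lm_scale_list=lm_scale_list,
)
# best_path_dict keys would then be "lm_scale_0.5", ..., "lm_scale_0.8",
# and save_results() later reports a separate WER for each key.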
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + + if params.decoding_method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + 
device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
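Note that only the first nbest-rescoring / whole-lattice-rescoring run is slow here: G_4_gram.fst.txt is parsed, the back-off disambiguation labels (#0) are mapped to epsilon, and the arc-sorted FSA is cached as G_4_gram.pt; later runs just hit the torch.load branch. A quick, hypothetical sanity check of the cached file (path is the default --lm-dir):

import k2
import torch

d = torch.load("data/lm/G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d)
print(G.shape, G.num_arcs)   # rough size check of the cached 4-gram LM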
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
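The comment above refers to the line that opens the next hunk: setting args.return_cuts = True makes lhotse's K2SpeechRecognitionDataset attach the originating Cut objects to each batch under batch["supervisions"]["cut"], which is how decode_dataset() recovers a stable cut.id to store next to every reference/hypothesis pair. A small sketch of that access pattern (assumes a dataloader like test_clean_dl below):

# Mirrors the access used in decode_dataset() above.
for batch in test_clean_dl:
    texts = batch["supervisions"]["text"]
    cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]   # requires return_cuts=True
    assert len(cut_ids) == len(texts)
    break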
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py new file mode 100755 index 000000000..01ba7b711 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -0,0 +1,857 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Yifan Yang,) +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
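Putting the pieces of ctc_decode.py together, the result files land under <exp-dir>/<decoding-method>/ with a suffix derived from the checkpoint selection. A hedged sketch of the paths produced by the first usage example in its docstring (--epoch 30 --avg 15 --decoding-method ctc-decoding, averaged model enabled by default):

exp_dir = "pruned_transducer_stateless7_ctc_bs/exp"
suffix = "epoch-30-avg-15-use-averaged-model"
res_dir = f"{exp_dir}/ctc-decoding"
for test_set in ["test-clean", "test-other"]:
    print(f"{res_dir}/recogs-{test_set}-{suffix}.txt")        # per-utterance transcripts
    print(f"{res_dir}/errs-{test_set}-{suffix}.txt")          # detailed error statistics
    print(f"{res_dir}/wer-summary-{test_set}-{suffix}.txt")   # WER per decoding setting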
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_ctc/ctc_guide_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model +from torch.nn.utils.rnn import pad_sequence + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + # filter out blank frames using ctc outputs + ctc_output = model.ctc_output(encoder_out) + encoder_out = model.lconv( + x=encoder_out, + src_key_padding_mask=make_pad_mask(encoder_out_lens), + ) + encoder_out, encoder_out_lens = model.frame_reducer( + x=encoder_out, + x_lens=encoder_out_lens, + ctc_output=ctc_output, + blank_id=0, + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == 
"greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py new file mode 100755 index 000000000..e497787d3 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -0,0 +1,835 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
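What distinguishes ctc_guide_decode_bs.py above (and this whole _ctc_bs recipe, presumably "bs" for blank skipping) is the step in decode_one_batch where the CTC posteriors gate the encoder output: frames the CTC head considers blank are dropped by model.frame_reducer before the transducer search, shrinking T and hence the search cost. The frame_reducer module itself is not part of this hunk; a simplified, hypothetical reducer in the same spirit might look like this (threshold and details are assumptions, not the recipe's implementation):

import torch

def reduce_blank_frames(x, x_lens, ctc_output, blank_id=0, blank_threshold=0.95):
    """Drop frames whose CTC blank posterior exceeds blank_threshold.

    A hypothetical stand-in for the recipe's frame_reducer module; the real
    implementation may differ in thresholding and padding details.
      x:          (N, T, C) encoder output
      x_lens:     (N,) valid frame counts
      ctc_output: (N, T, V) log-probabilities from the CTC head
    """
    N, T, C = x.shape
    in_range = torch.arange(T, device=x.device).unsqueeze(0) < x_lens.unsqueeze(1)
    keep = (ctc_output[..., blank_id].exp() < blank_threshold) & in_range
    new_lens = keep.sum(dim=1)
    out = x.new_zeros(N, int(new_lens.max()), C)
    for i in range(N):
        out[i, : new_lens[i]] = x[i, keep[i]]
    return out, new_lens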
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_ctc_bs/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
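+
+        For example, a greedy-search run over a two-utterance batch might
+        return something like (hypothetical transcripts):
+
+            {"greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}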
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter 
{params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
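+    # Setting return_cuts=True below makes the test dataloaders attach the
+    # lhotse Cut objects to each batch, so decode_dataset() can read cut.id
+    # for every utterance when storing the transcripts.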
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py new file mode 100755 index 000000000..05df8cfff --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. 
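+
+A minimal sketch for loading the exported model back in Python (the path below
+is a placeholder for your own exp_dir):
+
+    import torch
+
+    model = torch.jit.load("pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt")
+    model.eval()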
+ +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_ctc_bs/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7_ctc_bs/decode.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. 
+ It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
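+        # Ignoring forward() keeps torch.jit.script() from compiling it; the
+        # scripted submodules (encoder, decoder, joiner, ...) remain available
+        # and are presumably the parts invoked by the downstream C++ runtime.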
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py new file mode 100755 index 000000000..630a7f735 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -0,0 +1,897 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) +# 2023 NVIDIA Corporation (Author: Wen Ding) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to ONNX format + +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - frame_reducer.onnx + - ctc_output.onnx + +(2) Export to ONNX format which can be used in Triton Server +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --onnx-triton 1 + +It will generate the following files in the given `exp_dir`. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - ctc_output.onnx + +Please see ./onnx_pretrained.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +Note: If you don't want to train a model from scratch, we have +provided one for you. 
You can get it at + +https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool +from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. + It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - frame_reducer.onnx + - ctc_output.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + parser.add_argument( + "--onnx-triton", + type=str2bool, + default=False, + help="""If True, and it exports the model + to onnx format which can be used in NVIDIA triton server. + It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - ctc_output.onnx + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(15, 2000, 80, dtype=torch.float32) + x_lens = torch.tensor([2000] * 15, dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(15, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_decoder_model_onnx_triton( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX-Triton format. + The exported model has one input: + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + and has one output: + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
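+
+    Note: unlike export_decoder_model_onnx(), the decoder is first wrapped in
+    TritonOnnxDecoder, so the exported graph takes only `y` as input; the
+    `need_pad` argument is presumably fixed inside the wrapper.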
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + + decoder_model = TritonOnnxDecoder(decoder_model) + + torch.onnx.export( + decoder_model, + (y), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_joiner_model_onnx_triton( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + and produces one output: + - logit: a tensor of shape (N, vocab_size) + The exported encoder_proj model has one input: + - encoder_out: a tensor of shape (N, encoder_out_dim) + and produces one output: + - projected_encoder_out: a tensor of shape (N, joiner_dim) + The exported decoder_proj model has one input: + - decoder_out: a tensor of shape (N, decoder_out_dim) + and produces one output: + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + # Note: It uses torch.jit.trace() internally + joiner_model = TritonOnnxJoiner(joiner_model) + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_lconv_onnx( + lconv: nn.Module, + lconv_filename: str, + opset_version: int = 11, +) -> None: + """Export the lconv to ONNX format. + + The exported lconv has two inputs: + + - lconv_input: a tensor of shape (N, T, C) + - src_key_padding_mask: a tensor of shape (N, T) + + and has one output: + + - lconv_out: a tensor of shape (N, T, C) + + Args: + lconv: + The lconv to be exported. + lconv_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
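+
+    Note: the dummy tensors created below only pin the channel dimension
+    (384 here); the batch and time axes are declared dynamic via
+    `dynamic_axes`, so the exported model accepts other N and T at run time.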
+ """ + lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32) + src_key_padding_mask = torch.zeros(15, 498, dtype=torch.bool) + + torch.onnx.export( + lconv, + (lconv_input, src_key_padding_mask), + lconv_filename, + verbose=False, + opset_version=opset_version, + input_names=["lconv_input", "src_key_padding_mask"], + output_names=["lconv_out"], + dynamic_axes={ + "lconv_input": {0: "N", 1: "T"}, + "src_key_padding_mask": {0: "N", 1: "T"}, + "lconv_out": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {lconv_filename}") + + +def export_lconv_onnx_triton( + lconv: nn.Module, + lconv_filename: str, + opset_version: int = 11, +) -> None: + """Export the lconv to ONNX format. + + The exported lconv has two inputs: + + - lconv_input: a tensor of shape (N, T, C) + - lconv_input_lens: a tensor of shape (N, ) + + and has one output: + + - lconv_out: a tensor of shape (N, T, C) + + Args: + lconv: + The lconv to be exported. + lconv_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32) + lconv_input_lens = torch.tensor([498] * 15, dtype=torch.int64) + + lconv = TritonOnnxLconv(lconv) + + torch.onnx.export( + lconv, + (lconv_input, lconv_input_lens), + lconv_filename, + verbose=False, + opset_version=opset_version, + input_names=["lconv_input", "lconv_input_lens"], + output_names=["lconv_out"], + dynamic_axes={ + "lconv_input": {0: "N", 1: "T"}, + "lconv_input_lens": {0: "N"}, + "lconv_out": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {lconv_filename}") + + +def export_frame_reducer_onnx( + frame_reducer: nn.Module, + frame_reducer_filename: str, + opset_version: int = 11, +) -> None: + """Export the frame_reducer to ONNX format. + + The exported frame_reducer has four inputs: + + - x: a tensor of shape (N, T, C) + - x_lens: a tensor of shape (N, T) + - ctc_output: a tensor of shape (N, T, vocab_size) + - blank_id: an int, always 0 + + and has two outputs: + + - x_fr: a tensor of shape (N, T, C) + - x_lens_fr: a tensor of shape (N, T) + + Args: + frame_reducer: + The frame_reducer to be exported. + frame_reducer_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(15, 498, 384, dtype=torch.float32) + x_lens = torch.tensor([498] * 15, dtype=torch.int64) + ctc_output = torch.randn(15, 498, 500, dtype=torch.float32) + + torch.onnx.export( + frame_reducer, + (x, x_lens, ctc_output), + frame_reducer_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "ctc_output"], + output_names=["out", "out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "ctc_output": {0: "N", 1: "T"}, + "out": {0: "N", 1: "T"}, + "out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {frame_reducer_filename}") + + +def export_ctc_output_onnx( + ctc_output: nn.Module, + ctc_output_filename: str, + opset_version: int = 11, +) -> None: + """Export the frame_reducer to ONNX format. + + The exported frame_reducer has one inputs: + + - encoder_out: a tensor of shape (N, T, C) + + and has one output: + + - ctc_output: a tensor of shape (N, T, vocab_size) + + Args: + ctc_output: + The ctc_output to be exported. + ctc_output_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + encoder_out = torch.zeros(15, 498, 384, dtype=torch.float32) + + torch.onnx.export( + ctc_output, + (encoder_out), + ctc_output_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["ctc_output"], + dynamic_axes={ + "encoder_out": {0: "N", 1: "T"}, + "ctc_output": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {ctc_output_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + 
model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + if params.onnx is True: + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + elif params.onnx_triton is True: + export_decoder_model_onnx_triton( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + if params.onnx is True: + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + elif params.onnx_triton is True: + export_joiner_model_onnx_triton( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + lconv_filename = params.exp_dir / "lconv.onnx" + if params.onnx is True: + export_lconv_onnx( + model.lconv, + lconv_filename, + opset_version=opset_version, + ) + elif params.onnx_triton is True: + export_lconv_onnx_triton( + model.lconv, + lconv_filename, + opset_version=opset_version, + ) + + if params.onnx is True: + frame_reducer_filename = params.exp_dir / "frame_reducer.onnx" + export_frame_reducer_onnx( + model.frame_reducer, + frame_reducer_filename, + opset_version=opset_version, + ) + + ctc_output_filename = params.exp_dir / "ctc_output.onnx" + export_ctc_output_onnx( + model.ctc_output, + ctc_output_filename, + opset_version=opset_version, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py new file mode 100644 index 000000000..0841f7cf1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from icefall.utils import make_pad_mask + + +class FrameReducer(nn.Module): + """The encoder output is first used to calculate + the CTC posterior probability; then for each output frame, + if its blank posterior is bigger than some thresholds, + it will be simply discarded from the encoder output. + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ctc_output: torch.Tensor, + y_lens: Optional[torch.Tensor] = None, + blank_id: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The shared encoder output with shape [N, T, C]. + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + ctc_output: + The CTC output with shape [N, T, vocab_size]. 
+ y_lens: + A tensor of shape (batch_size,) containing the number of frames in + `y` before padding. + blank_id: + The blank id of ctc_output. + Returns: + out: + The frame reduced encoder output with shape [N, T', C]. + out_lens: + A tensor of shape (batch_size,) containing the number of frames in + `out` before padding. + """ + N, T, C = x.size() + + padding_mask = make_pad_mask(x_lens) + non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) + + if y_lens is not None: + # Limit the maximum number of reduced frames + limit_lens = T - y_lens + max_limit_len = limit_lens.max().int() + fake_limit_indexes = torch.topk( + ctc_output[:, :, blank_id], max_limit_len + ).indices + T = ( + torch.arange(max_limit_len) + .expand_as( + fake_limit_indexes, + ) + .to(device=x.device) + ) + T = torch.remainder(T, limit_lens.unsqueeze(1)) + limit_indexes = torch.gather(fake_limit_indexes, 1, T) + limit_mask = torch.full_like( + non_blank_mask, + False, + device=x.device, + ).scatter_(1, limit_indexes, True) + + non_blank_mask = non_blank_mask | ~limit_mask + + out_lens = non_blank_mask.sum(dim=1) + max_len = out_lens.max() + pad_lens_list = ( + torch.full_like( + out_lens, + max_len.item(), + device=x.device, + ) + - out_lens + ) + max_pad_len = pad_lens_list.max() + + out = F.pad(x, (0, 0, 0, max_pad_len)) + + valid_pad_mask = ~make_pad_mask(pad_lens_list) + total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1) + + out = out[total_valid_mask].reshape(N, -1, C) + + return out, out_lens + + +if __name__ == "__main__": + import time + + test_times = 10000 + device = "cuda:0" + frame_reducer = FrameReducer() + + # non zero case + x = torch.ones(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.log( + torch.randn(15, 498, 500, dtype=torch.float32, device=device), + ) + + avg_time = 0 + for i in range(test_times): + torch.cuda.synchronize(device=x.device) + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) + + # all zero case + x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device) + + avg_time = 0 + for i in range(test_times): + torch.cuda.synchronize(device=x.device) + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py new file mode 100755 index 000000000..da2c6a39a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. 
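As a small aside on the batched loop below (illustrative values): pack_padded_sequence sorts utterances by length, so packed_encoder_out.batch_sizes tells us how many utterances are still active at each time step.

import torch
from torch.nn.utils.rnn import pack_padded_sequence

encoder_out = torch.randn(3, 5, 4)                      # (N, T, C)
encoder_out_lens = torch.tensor([5, 3, 2])
packed = pack_padded_sequence(
    encoder_out, encoder_out_lens, batch_first=True, enforce_sorted=False
)
print(packed.batch_sizes.tolist())                      # [3, 3, 2, 1, 1]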
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s 
[%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py new file mode 100755 index 000000000..653c25e06 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(3) nbest-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + 
required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [ + [i, 0, feature_lengths[i] // params.subsampling_factor] + for i in range(batch_size) + ], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py new file mode 100644 index 000000000..a902358ae --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py @@ -0,0 +1,114 @@ +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from scaling import ( + ActivationBalancer, + ScaledConv1d, +) + + +class LConv(nn.Module): + """A convolution module to prevent information loss.""" + + def __init__( + self, + channels: int, + kernel_size: int = 7, + bias: bool = True, + ): + """ + Args: + channels: + Dimension of the input embedding, and of the lconv output. 
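A simplified sketch of the shape flow implemented below, with plain nn.Conv1d standing in for icefall's ActivationBalancer and ScaledConv1d (toy sizes, untrained weights):

import torch
import torch.nn as nn

channels, kernel_size = 384, 7
pw1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
dw = nn.Conv1d(2 * channels, 2 * channels, kernel_size,
               padding=(kernel_size - 1) // 2, groups=2 * channels)
pw2 = nn.Conv1d(2 * channels, channels, kernel_size=1)

x = torch.randn(2, 100, channels)                       # (N, T, C)
y = pw2(dw(pw1(x.permute(0, 2, 1)))).permute(0, 2, 1)   # Conv1d operates on (N, C, T)
print(y.shape)                                          # torch.Size([2, 100, 384])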
+ """ + super().__init__() + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + self.depthwise_conv = nn.Conv1d( + 2 * channels, + 2 * channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=2 * channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + 2 * channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + max_abs=20.0, + ) + + self.pointwise_conv2 = ScaledConv1d( + 2 * channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + def forward( + self, + x: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: A 3-D tensor of shape (N, T, C). + Returns: + Return a tensor of shape (N, T, C). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(0, 2, 1) # (#batch, channels, time). + + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + + x = self.pointwise_conv2(x) # (batch, channels, time) + + return x.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py new file mode 100644 index 000000000..0582b289f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -0,0 +1,226 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos, make_pad_mask + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + lconv: nn.Module, + frame_reducer: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). 
+ Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + self.lconv = lconv + self.frame_reducer = frame_reducer + + self.simple_am_proj = nn.Linear( + encoder_dim, + vocab_size, + ) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A floating point value which decides whether to do blank skip. + Returns: + Return a tuple containing simple loss, pruned loss, and ctc-output. + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # compute ctc log-probs + ctc_output = self.ctc_output(encoder_out) + + # y_lens + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + # blank skip + blank_id = self.decoder.blank_id + + if warmup >= 2.0: + # lconv + encoder_out = self.lconv( + x=encoder_out, + src_key_padding_mask=make_pad_mask(x_lens), + ) + + # frame reduce + encoder_out_fr, x_lens_fr = self.frame_reducer( + encoder_out, + x_lens, + ctc_output, + y_lens, + blank_id, + ) + else: + encoder_out_fr = encoder_out + x_lens_fr = x_lens + + # sos_y + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. 
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens_fr + + am = self.simple_am_proj(encoder_out_fr) + lm = self.simple_lm_proj(decoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out_fr), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss, ctc_output) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py new file mode 100755 index 000000000..8ff02fbcb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. 
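Every component below is driven through the same onnxruntime pattern: create an InferenceSession, feed inputs by name, fetch outputs by name. A minimal sketch (the model path and input shapes are placeholders):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("exp/encoder.onnx")      # placeholder path
feed = {
    session.get_inputs()[0].name: np.random.randn(1, 100, 80).astype(np.float32),
    session.get_inputs()[1].name: np.array([100], dtype=np.int64),
}
outputs = session.run([o.name for o in session.get_outputs()], feed)
print([o.shape for o in outputs])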
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +Usage of this script: + +./pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_decoder_proj.onnx \ + --lconv-filename ./pruned_transducer_stateless7_ctc_bs/exp/lconv.onnx \ + --frame-reducer-filename ./pruned_transducer_stateless7_ctc_bs/exp/frame_reducer.onnx \ + --ctc-output-filename ./pruned_transducer_stateless7_ctc_bs/exp/ctc_output.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence + +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--lconv-filename", + type=str, + required=True, + help="Path to the lconv onnx model. ", + ) + + parser.add_argument( + "--frame-reducer-filename", + type=str, + required=True, + help="Path to the frame reducer onnx model. ", + ) + + parser.add_argument( + "--ctc-output-filename", + type=str, + required=True, + help="Path to the ctc_output onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. 
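If an input file is not already sampled at 16 kHz, it can be resampled before decoding; a small example using torchaudio (file names are placeholders):

import torchaudio
import torchaudio.functional as F

wave, sr = torchaudio.load("foo_8k.wav")                # placeholder input file
if sr != 16000:
    wave = F.resample(wave, orig_freq=sr, new_freq=16000)
torchaudio.save("foo_16k.wav", wave, 16000)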
+ expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + )[0] + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = projected_encoder_out[start:end] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + projected_decoder_out = projected_decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: np.expand_dims( + np.expand_dims(current_encoder_out, axis=1), axis=1 + ), + joiner_input_nodes[1] + .name: projected_decoder_out.unsqueeze(1) + .unsqueeze(1) + .numpy(), + }, + )[0] + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = 
[h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + + lconv = ort.InferenceSession( + args.lconv_filename, + sess_options=session_opts, + ) + + frame_reducer = ort.InferenceSession( + args.frame_reducer_filename, + sess_options=session_opts, + ) + + ctc_output = ort.InferenceSession( + args.ctc_output_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + ctc_output_input_nodes = ctc_output.get_inputs() + ctc_output_out_nodes = ctc_output.get_outputs() + ctc_out = ctc_output.run( + [ctc_output_out_nodes[0].name], + { + ctc_output_input_nodes[0].name: encoder_out, + }, + )[0] + + lconv_input_nodes = lconv.get_inputs() + lconv_out_nodes = lconv.get_outputs() + encoder_out = lconv.run( + [lconv_out_nodes[0].name], + { + lconv_input_nodes[0].name: encoder_out, + lconv_input_nodes[1] + .name: make_pad_mask(torch.from_numpy(encoder_out_lens)) + .numpy(), + }, + )[0] + + frame_reducer_input_nodes = frame_reducer.get_inputs() + frame_reducer_out_nodes = frame_reducer.get_outputs() + encoder_out_fr, 
encoder_out_lens_fr = frame_reducer.run( + [frame_reducer_out_nodes[0].name, frame_reducer_out_nodes[1].name], + { + frame_reducer_input_nodes[0].name: encoder_out, + frame_reducer_input_nodes[1].name: encoder_out_lens, + frame_reducer_input_nodes[2].name: ctc_out, + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, + encoder_out=encoder_out_fr, + encoder_out_lens=encoder_out_lens_fr, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py new file mode 100755 index 000000000..247da0949 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py @@ -0,0 +1,98 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from icefall.utils import make_pad_mask + + +class TritonOnnxDecoder(nn.Module): + """ + Triton wrapper for decoder model + """ + + def __init__(self, model): + """ + Args: + model: decoder model + """ + super().__init__() + + self.model = model + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + need_pad = False + return self.model(y, need_pad) + + +class TritonOnnxJoiner(nn.Module): + def __init__( + self, + model, + ): + super().__init__() + + self.model = model + self.encoder_proj = model.encoder_proj + self.decoder_proj = model.decoder_proj + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, C). + decoder_out: + Output from the decoder. Its shape is (N, T, C). + Returns: + Return a tensor of shape (N, T, C). + """ + project_input = False + return self.model(encoder_out, decoder_out, project_input) + + +class TritonOnnxLconv(nn.Module): + def __init__( + self, + model, + ): + super().__init__() + + self.model = model + + def forward( + self, + lconv_input: torch.Tensor, + lconv_input_lens: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + lconv_input: Its shape is (N, T, C). + lconv_input_lens: Its shape is (N, ). + Returns: + Return a tensor of shape (N, T, C). 
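The wrappers in this file exist so that torch.onnx.export only ever sees tensor inputs; non-tensor arguments such as need_pad or project_input are pinned inside the wrapper. A self-contained toy illustration of the pattern (ToyDecoder and the output file name are made up, not icefall modules):

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        return y.float() + (1.0 if need_pad else 0.0)

class ToyTritonDecoder(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        return self.model(y, need_pad=False)            # pin the non-tensor argument

torch.onnx.export(
    ToyTritonDecoder(ToyDecoder()),
    torch.zeros(2, 3, dtype=torch.int64),
    "toy_decoder.onnx",                                 # made-up output path
    input_names=["y"],
    output_names=["decoder_out"],
)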
+ """ + mask = make_pad_mask(lconv_input_lens) + + return self.model(x=lconv_input, src_key_padding_mask=mask) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py new file mode 100755 index 000000000..ea0fe9164 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_ctc_bs/exp/epoch-xx.pt`. 
+ +Note: ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_ctc_bs/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp 
in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py new file mode 100755 index 000000000..412631ba1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from 
icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. 
+ Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + params.blank_id = 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, 
map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py new file mode 100755 index 000000000..7f0893985 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/test_model.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_ctc_bs/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model_1(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model_1() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py new file mode 100755 index 000000000..ea280e642 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -0,0 +1,1276 @@ +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
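(Editorial aside, not part of the diff: test_model.py above builds the transducer from plain comma-separated hyperparameter strings and then counts parameters. Below is a minimal self-contained sketch of that convention, assuming only the recipe's default values; to_int_tuple mirrors the helper defined later inside get_encoder_model in train.py.)

import torch.nn as nn


def to_int_tuple(s: str) -> tuple:
    # "2,4,3,2,4" -> (2, 4, 3, 2, 4)
    return tuple(int(x) for x in s.split(","))


per_stack = {
    "num_encoder_layers": to_int_tuple("2,4,3,2,4"),
    "feedforward_dims": to_int_tuple("1024,1024,2048,2048,1024"),
    "nhead": to_int_tuple("8,8,8,8,8"),
    "encoder_dims": to_int_tuple("384,384,384,384,384"),
    "zipformer_downsampling_factors": to_int_tuple("1,2,4,8,2"),
}
# Every per-stack option is expected to describe the same number of encoder stacks.
assert len({len(v) for v in per_stack.values()}) == 1

# The parameter-count idiom used in test_model.py, shown on a toy module.
toy = nn.Linear(80, 384)
print(f"Number of model parameters: {sum(p.numel() for p in toy.parameters())}")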
+""" +Usage: +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --full-libri 1 \ + --max-duration 300 +# For mix precision training: +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --full-libri 1 \ + --max-duration 750 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from frame_reducer import FrameReducer +from joiner import Joiner +from lconv import LConv +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram, 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.5, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + Explanation of options saved in `params`: + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + - best_train_epoch: It is the epoch that has the best training loss. + - best_valid_epoch: It is the epoch that has the best validation loss. + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + - log_interval: Print training loss if batch_idx % log_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - valid_interval: Run validation if batch_idx % valid_interval is 0 + - feature_dim: The model input dim. It has to match the one used + in computing features. + - subsampling_factor: The subsampling factor for the model. + - encoder_dim: Hidden dim for multi-head attention model. 
+ - num_decoder_layers: Number of decoder layer of transformer decoder. + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for ctc loss + "beam_size": 10, + "use_double_scores": True, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_lconv(params: AttributeDict) -> nn.Module: + lconv = LConv( + channels=int(params.encoder_dims.split(",")[-1]), + ) + return lconv + + +def get_frame_reducer(params: AttributeDict) -> nn.Module: + frame_reducer = FrameReducer() + return frame_reducer + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + lconv = get_lconv(params) + frame_reducer = get_frame_reducer(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + lconv=lconv, + frame_reducer=frame_reducer, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. 
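(Editorial sketch, not part of the diff: the resume behaviour described above can be summarised by a tiny stand-alone function mirroring the if/elif chain in the function body that follows; the exp directory and the numbers are just examples.)

from pathlib import Path
from typing import Optional


def checkpoint_to_load(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    # Mirrors load_checkpoint_if_available(): --start-batch takes priority over --start-epoch.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run, nothing to load


print(checkpoint_to_load(Path("exp"), start_batch=0, start_epoch=1))     # None
print(checkpoint_to_load(Path("exp"), start_batch=0, start_epoch=3))     # exp/epoch-2.pt
print(checkpoint_to_load(Path("exp"), start_batch=5000, start_epoch=3))  # exp/checkpoint-5000.pt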
+ Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
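(Editorial illustration of the warmup schedule implemented a few lines below; warm_step=2000 and a simple-loss scale of 0.5 are the recipe defaults, the batch indices are arbitrary.)

def warmup_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    # Mirrors the simple/pruned loss scaling computed in compute_loss().
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac


for step in (0, 1000, 2000, 10000):
    print(step, warmup_scales(step))
# 0 -> (1.0, 0.1), 1000 -> (0.75, 0.55), 2000 and beyond -> (0.5, 1.0)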
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + warmup = batch_idx_train / warm_step + + texts = batch["supervisions"]["text"] + token_ids = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(token_ids).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_output = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + # Compute ctc loss + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + # Works with a BPE model + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
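# (Editorial note) Because reduction="sum" is used for all three losses, the
# values stored in `info` below are sums over every frame/utterance in the
# batch; they only become per-frame numbers when divided by info["frames"],
# as done via tot_loss["loss"] / tot_loss["frames"] in compute_validation_loss()
# and train_one_epoch().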
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
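# (Editorial note) Standard torch.cuda.amp pattern used below: scaler.scale()
# multiplies the loss by the current loss scale before backward(); scaler.step()
# unscales the gradients and skips optimizer.step() when inf/NaN gradients are
# detected; scaler.update() then adapts the scale for the next iteration.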
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
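(Editorial sketch: a batch saved by this helper is a plain dict written with torch.save, so it can be reloaded later for offline inspection; the filename below is hypothetical.)

import torch

batch = torch.load("pruned_transducer_stateless7_ctc_bs/exp/batch-1234abcd.pt", map_location="cpu")  # hypothetical file
print(batch["inputs"].shape)               # features, (N, T, C)
print(batch["supervisions"]["text"][:2])   # first two reference transcripts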
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md new file mode 100644 index 000000000..d3691e647 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md @@ -0,0 +1,10 @@ +This recipe implements Streaming Zipformer-Transducer model. + +See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials. + +[./emformer.py](./emformer.py) and [./train.py](./train.py) +are basically the same as +[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). 
+ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..3444f8193 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,984 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for 
decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += 30 + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, 30), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + 
encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"beam_size_{params.beam_size}_{key}"] = hyps + return ans + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + ngram_lm: + A n-gram LM to be used for LODR. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
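To make the dictionary contract described in the docstring above concrete, here is a small illustrative sketch (not part of decode.py); the key name, cut ids and transcripts are invented for the example, but the shapes mirror what decode_one_batch() returns and what decode_dataset() accumulates below.

from collections import defaultdict

# Hypothetical output of decode_one_batch() for a two-utterance batch decoded
# with modified_beam_search and a beam size of 4.
hyps_dict = {
    "beam_size_4": [
        ["HELLO", "WORLD"],
        ["GOOD", "MORNING"],
    ]
}

# decode_dataset() pairs each hypothesis with its cut id and reference words.
results = defaultdict(list)
cut_ids = ["cut-0001", "cut-0002"]       # invented ids
texts = ["HELLO WORLD", "GOOD MORNING"]  # invented reference transcripts
for name, hyps in hyps_dict.items():
    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
        results[name].append((cut_id, ref_text.split(), hyp_words))

print(results["beam_size_4"][0])
# ('cut-0001', ['HELLO', 'WORLD'], ['HELLO', 'WORLD'])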
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + 
params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_shallow_fusion:
+        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+    if "LODR" in params.decoding_method:
+        params.suffix += (
+            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+        )
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+    assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
+        model.encoder.decode_chunk_size,
+        params.decode_chunk_len,
+    )
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
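A quick worked example of the epoch arithmetic above, using only the script defaults (--epoch 30, --avg 9); the numbers are assumptions for illustration, not results from any experiment.

# Illustrative only: which checkpoints --use-averaged-model covers by default.
epoch = 30  # --epoch (default)
avg = 9     # --avg (default)
start = epoch - avg                                # 21; epoch-21.pt is the excluded endpoint
averaged_over = list(range(start + 1, epoch + 1))  # epochs 22..30, i.e. 9 epochs
print(averaged_over)

+            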
model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + import time + + for test_set, test_dl in zip(test_sets, test_dl): + start = time.time() + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + logging.info(f"Elasped time for {test_set}: {time.time() - start}") + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 100644 index 000000000..0d7e86fcf --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1,151 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + initial_states: List[torch.Tensor], + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + if params.decoding_method == "fast_beam_search": + assert decoding_graph is not None + assert device == decoding_graph.device + + self.params = params + self.cut_id = cut_id + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. + self.hyp: List = [] + + # how many frames have been processed, after subsampling (i.e. 
a + # cumulative sum of the second return value of + # encoder.streaming_forward + self.done_frames: int = 0 + + # It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2 + # 1) feature embedding: out_lens=(x_lens-7)//2 + # 2) output subsampling: out_lens=(out_lens+1)//2 + self.pad_length = 7 + + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( + decoding_graph + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut_id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.params.decoding_method == "greedy_search": + return self.hyp[self.params.context_size :] # noqa + elif self.params.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.params.context_size :] # noqa + else: + assert self.params.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py new file mode 100755 index 000000000..1f870ca5a --- /dev/null +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_char_bpe/L.pt" +git lfs pull --include "data/lang_char_bpe/L_disambig.pt" +git lfs pull --include "data/lang_char_bpe/Linv.pt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ + --lang-dir $repo/data/lang_char_bpe \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. 
+ + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + 
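As a sanity check once this script has written its outputs, the traced joiner can be loaded back and run on dummy inputs. This is an illustrative sketch only: the file name matches the one exported further below, while the input dimensions (384 for the encoder output, 512 for the decoder output) are assumptions taken from the example command in the module docstring.

import torch

joiner = torch.jit.load("joiner_jit_trace-pnnx.pt")
encoder_out = torch.rand(1, 384, dtype=torch.float32)  # assumed encoder output dim
decoder_out = torch.rand(1, 512, dtype=torch.float32)  # assumed decoder output dim
# The traced joiner applies the encoder/decoder projections itself
# (project_input was fixed to True at trace time), so it expects
# unprojected encoder/decoder outputs.
logit = joiner(encoder_out, decoder_out)
print(logit.shape)  # (1, vocab_size)

+    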
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 100755 index 000000000..f5589d1b2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. 
+""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. 
+    """
+    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
+    need_pad = torch.tensor([False])
+
+    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
+    traced_model.save(decoder_filename)
+    logging.info(f"Saved to {decoder_filename}")
+
+
+def export_joiner_model_jit_trace(
+    joiner_model: torch.nn.Module,
+    joiner_filename: str,
+) -> None:
+    """Export the given joiner model with torch.jit.trace()
+
+    Note: The argument project_input is fixed to True. A user should not
+    project the encoder_out/decoder_out by himself/herself. The exported joiner
+    will do that for the user.
+
+    Args:
+      joiner_model:
+        The input joiner model
+      joiner_filename:
+        The filename to save the exported model.
+
+    """
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
+    traced_model.save(joiner_filename)
+    logging.info(f"Saved to {joiner_filename}")
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+
+    setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn")
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+ start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py new file mode 100755 index 000000000..04d97808d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -0,0 +1,678 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_char_bpe/L.pt" +git lfs pull --include "data/lang_char_bpe/Linv.pt" +git lfs pull --include "data/lang_char_bpe/L_disambig.pt" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ + --lang-dir $repo/data/lang_char_bpe \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to use the exported models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. 
+ """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + """Please see the help information of Zipformer.streaming_forward""" + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, device=x.device) + + output, _, new_states = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + output = self.encoder_proj(output) + # Now output is of shape (N, T, joiner_dim) + + return output, new_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """ + Onnx model inputs: + - 0: src + - many state tensors (the exact number depending on the actual model) + + Onnx model outputs: + - 0: output, its shape is (N, T, joiner_dim) + - many state tensors (the exact number depending on the actual model) + + Args: + encoder_model: + The model to be exported + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"T: {T}") + + x = torch.rand(1, T, 80, dtype=torch.float32) + + init_state = encoder_model.encoder.get_init_state() + + num_encoders = encoder_model.encoder.num_encoders + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, name, N): + for i, s in enumerate(tensors): + logging.info(f"{name}_{i}.shape: {s.shape}") + inputs[f"{name}_{i}"] = {N: "N"} + outputs[f"new_{name}_{i}"] = {N: "N"} + input_names.append(f"{name}_{i}") + output_names.append(f"new_{name}_{i}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) + attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) + ds = encoder_model.encoder.zipformer_downsampling_factors + left_context_len = encoder_model.encoder.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + "attention_dims": attention_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + } + logging.info(f"meta_data: {meta_data}") + + # (num_encoder_layers, 1) + cached_len = init_state[num_encoders * 0 : num_encoders * 1] + + # (num_encoder_layers, 1, encoder_dim) + cached_avg = init_state[num_encoders * 1 : num_encoders * 2] + + # (num_encoder_layers, left_context_len, 1, attention_dim) + cached_key = init_state[num_encoders * 2 : num_encoders * 3] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val = init_state[num_encoders * 3 : num_encoders * 4] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] + + build_inputs_outputs(cached_len, "cached_len", 1) + build_inputs_outputs(cached_avg, "cached_avg", 1) + build_inputs_outputs(cached_key, "cached_key", 2) + build_inputs_outputs(cached_val, "cached_val", 2) + build_inputs_outputs(cached_val2, "cached_val2", 2) + build_inputs_outputs(cached_conv1, "cached_conv1", 1) + build_inputs_outputs(cached_conv2, "cached_conv2", 1) + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + 
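For readers wondering how the streaming configuration travels with the model: add_meta_data() stores it as ONNX metadata on the exported file, and it can be read back at load time. A minimal sketch, using the output name given in the module docstring for --epoch 99 --avg 1:

import onnx

m = onnx.load("encoder-epoch-99-avg-1.onnx")
meta = {p.key: p.value for p in m.metadata_props}
print(meta["decode_chunk_len"])    # e.g. "32"
print(meta["num_encoder_layers"])  # e.g. "2,4,3,2,4"

The same key/value pairs can equally be inspected at inference time through onnxruntime's InferenceSession.get_modelmeta().custom_metadata_map.

+    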
add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + 
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + if params.use_averaged_model: + suffix += "-with-averaged-model" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 
quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 100755 index 000000000..e71bcaf29 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1,669 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to use the exported models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + """Please see the help information of Zipformer.streaming_forward""" + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, device=x.device) + + output, _, new_states = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + output = self.encoder_proj(output) + # Now output is of shape (N, T, joiner_dim) + + return output, new_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). 
+ Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """ + Onnx model inputs: + - 0: src + - many state tensors (the exact number depending on the actual model) + + Onnx model outputs: + - 0: output, its shape is (N, T, joiner_dim) + - many state tensors (the exact number depending on the actual model) + + Args: + encoder_model: + The model to be exported + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"T: {T}") + + x = torch.rand(1, T, 80, dtype=torch.float32) + + init_state = encoder_model.encoder.get_init_state() + + num_encoders = encoder_model.encoder.num_encoders + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, name, N): + for i, s in enumerate(tensors): + logging.info(f"{name}_{i}.shape: {s.shape}") + inputs[f"{name}_{i}"] = {N: "N"} + outputs[f"new_{name}_{i}"] = {N: "N"} + input_names.append(f"{name}_{i}") + output_names.append(f"new_{name}_{i}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) + attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) + ds = encoder_model.encoder.zipformer_downsampling_factors + left_context_len = encoder_model.encoder.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + 
"attention_dims": attention_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + } + logging.info(f"meta_data: {meta_data}") + + # (num_encoder_layers, 1) + cached_len = init_state[num_encoders * 0 : num_encoders * 1] + + # (num_encoder_layers, 1, encoder_dim) + cached_avg = init_state[num_encoders * 1 : num_encoders * 2] + + # (num_encoder_layers, left_context_len, 1, attention_dim) + cached_key = init_state[num_encoders * 2 : num_encoders * 3] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val = init_state[num_encoders * 3 : num_encoders * 4] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] + + build_inputs_outputs(cached_len, "cached_len", 1) + build_inputs_outputs(cached_avg, "cached_avg", 1) + build_inputs_outputs(cached_key, "cached_key", 2) + build_inputs_outputs(cached_val, "cached_val", 2) + build_inputs_outputs(cached_val2, "cached_val2", 2) + build_inputs_outputs(cached_conv1, "cached_conv1", 1) + build_inputs_outputs(cached_conv2, "cached_conv2", 1) + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + 
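As an aside (not part of the diff): the average_checkpoints_with_averaged_model() call a few lines below combines the running parameter averages stored in these two checkpoints instead of re-averaging every epoch file. Conceptually it amounts to the sketch below; the key names are illustrative and the real implementation in icefall/checkpoint.py handles devices, dtypes and edge cases:

import torch

def interval_average(start_ckpt: dict, end_ckpt: dict) -> dict:
    """Average of the parameters seen strictly after start_ckpt, assuming each
    checkpoint stores a running average `model_avg` plus a batch counter."""
    n_start = start_ckpt["batch_idx_train"]
    n_end = end_ckpt["batch_idx_train"]
    avg = {}
    for name, p_end in end_ckpt["model_avg"].items():
        p_start = start_ckpt["model_avg"][name]
        # Undo the prefix that both running averages share.
        avg[name] = (p_end * n_end - p_start * n_start) / (n_end - n_start)
    return avg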
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + if params.use_averaged_model: + suffix += "-with-averaged-model" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 100755 index 000000000..5735ee692 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1,878 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See 
../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7_streaming/decode.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp + +(3) Export to ONNX format with pretrained.pt + +cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp +ln -s pretrained.pt epoch-999.pt +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model False \ + --epoch 999 \ + --avg 1 \ + --fp16 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. 
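Before wiring the files from (3) into sherpa-onnx or ./onnx_check.py, a lightweight structural check is possible with the onnx package alone. This is a hedged sketch, not part of the recipe; it only assumes the file names listed above:

import onnx

for name in [
    "encoder.onnx",
    "decoder.onnx",
    "joiner.onnx",
    "joiner_encoder_proj.onnx",
    "joiner_decoder_proj.onnx",
]:
    m = onnx.load(name)
    onnx.checker.check_model(m)  # raises if the graph is malformed
    print(
        name,
        [i.name for i in m.graph.input],
        [o.name for o in m.graph.output],
    )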
+ +(4) Export to ONNX format for triton server + +cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp +ln -s pretrained.pt epoch-999.pt +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model False \ + --epoch 999 \ + --avg 1 \ + --fp16 \ + --onnx-triton 1 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + +Check +https://github.com/k2-fsa/sherpa/tree/master/triton +for how to use the exported models outside of icefall. + +""" + + +import argparse +import logging +from pathlib import Path + +import onnxruntime +import sentencepiece as spm +import torch +import torch.nn as nn +from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. 
+ """, + ) + + parser.add_argument( + "--onnx-triton", + type=str2bool, + default=False, + help="""If True, --onnx would export model into the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton. + """, + ) + + parser.add_argument( + "--fp16", + action="store_true", + help="whether to export fp16 onnx model, default false", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): + for a, b in zip(xlist, blist): + try: + torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + except AssertionError as error: + if tolerate_small_mismatch: + print("small mismatch detected", error) + else: + return False + return True + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + batch_size = 17 + seq_len = 101 + torch.manual_seed(0) + x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32) + x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. 
+ # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + initial_states = [encoder_model.get_init_state() for _ in range(batch_size)] + states = stack_states(initial_states) + + left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks + encoder_attention_dim = encoder_model.encoders[0].attention_dim + + len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15 + avg_cache = torch.cat( + states[encoder_model.num_encoders : 2 * encoder_model.num_encoders] + ).transpose( + 0, 1 + ) # [B,15,384] + cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose( + 0, 1 + ) # [B,2*15,384,cnn_kernel-1] + pad_tensors = [ + torch.nn.functional.pad( + tensor, + ( + 0, + encoder_attention_dim - tensor.shape[-1], + 0, + 0, + 0, + left_context_len - tensor.shape[1], + 0, + 0, + ), + ) + for tensor in states[ + 2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders + ] + ] + attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] + + encoder_model_wrapper = OnnxStreamingEncoder(encoder_model) + + torch.onnx.export( + encoder_model_wrapper, + (x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "x", + "x_lens", + "len_cache", + "avg_cache", + "attn_cache", + "cnn_cache", + ], + output_names=[ + "encoder_out", + "encoder_out_lens", + "new_len_cache", + "new_avg_cache", + "new_attn_cache", + "new_cnn_cache", + ], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + "len_cache": {0: "N"}, + "avg_cache": {0: "N"}, + "attn_cache": {0: "N"}, + "cnn_cache": {0: "N"}, + "new_len_cache": {0: "N"}, + "new_avg_cache": {0: "N"}, + "new_attn_cache": {0: "N"}, + "new_cnn_cache": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + # Test onnx encoder with torch native encoder + encoder_model.eval() + ( + encoder_out_torch, + encoder_out_lens_torch, + new_states_torch, + ) = encoder_model.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + ort_session = onnxruntime.InferenceSession( + str(encoder_filename), providers=["CPUExecutionProvider"] + ) + ort_inputs = { + "x": x.numpy(), + "x_lens": x_lens.numpy(), + "len_cache": len_cache.numpy(), + "avg_cache": avg_cache.numpy(), + "attn_cache": attn_cache.numpy(), + "cnn_cache": cnn_cache.numpy(), + } + ort_outs = ort_session.run(None, ort_inputs) + + assert test_acc( + [encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2] + ) + logging.info(f"{encoder_filename} acc test succeeded.") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_decoder_model_onnx_triton( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + + decoder_model = TritonOnnxDecoder(decoder_model) + + torch.onnx.export( + decoder_model, + (y,), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + 
verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_joiner_model_onnx_triton( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported model has two inputs: + - encoder_out: a tensor of shape (N, encoder_out_dim) + - decoder_out: a tensor of shape (N, decoder_out_dim) + and has one output: + - joiner_out: a tensor of shape (N, vocab_size) + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + joiner_model = TritonOnnxJoiner(joiner_model) + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (encoder_out, decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out", "decoder_out"], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, 
device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.onnx: + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + if not params.onnx_triton: + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + else: + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx_triton( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx_triton( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + if params.fp16: + try: + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + except ImportError: + print("Please install onnxmltools!") + import sys + + sys.exit(1) + + def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_fp16_filename) + + decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx" + export_onnx_fp16(decoder_filename, decoder_fp16_filename) + + joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx" + export_onnx_fp16(joiner_filename, joiner_fp16_filename) + + if not params.onnx_triton: + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + encoder_proj_fp16_filename = ( + params.exp_dir / "joiner_encoder_proj_fp16.onnx" + ) + export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + decoder_proj_fp16_filename 
= ( + params.exp_dir / "joiner_decoder_proj_fp16.onnx" + ) + export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename) + + elif params.jit: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + model.encoder.__class__.forward = model.encoder.__class__.streaming_forward + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py new file mode 100755 index 000000000..4fd5e1820 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_streaming/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. 
" + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + model = 
torch.jit.load(args.nn_model_filename) + model.encoder.decode_chunk_size = args.decode_chunk_len // 2 + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 100755 index 000000000..a164f3f69 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 + +""" +Usage: +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, + params: AttributeDict, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + decode_chunk_len = params.decode_chunk_len # before subsampling + pad_length = 7 + s = f"decode_chunk_len: {decode_chunk_len}" + logging.info(s) + assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( + encoder_model.decode_chunk_size, + decode_chunk_len, + ) + + T = decode_chunk_len + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder_model.get_init_state(device=x.device) + + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. 
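Example (illustrative, not part of the traced API itself): once traced, the joiner can be
    called directly as joiner(encoder_out, decoder_out) with 2-D inputs of shape
    (N, encoder_out_dim) and (N, decoder_out_dim); the input projections are applied
    inside the traced module since project_input is fixed to True.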
+ + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename, params) + + 
logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 100755 index 000000000..f2ac1914d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ +""" + +import argparse +import logging +import math +from typing import List, Optional + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. 
", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = args.decode_chunk_len + assert encoder.decode_chunk_size == chunk_length // 2, ( + encoder.decode_chunk_size, + chunk_length, + ) + + # we subsample features with ((x_len - 7) // 2 + 1) // 2 + pad_length = 7 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + + states = encoder.get_init_state(device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32) + encoder_out, out_lens, states = encoder( + x=frames, + x_lens=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py new 
file mode 120000 index 000000000..e17d4f734 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 100755 index 000000000..d7a4b9551 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script checks that exported ONNX models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model via torch.jit.trace() + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp + + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + +3. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. 
Run this file + +./pruned_transducer_stateless7_streaming/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx +""" + +import argparse +import logging + +import torch +from onnx_pretrained import OnnxModel +from zipformer import stack_states + +from icefall import is_module_available + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-encoder-filename", + required=True, + type=str, + help="Path to the torchscript encoder model", + ) + + parser.add_argument( + "--jit-decoder-filename", + required=True, + type=str, + help="Path to the torchscript decoder model", + ) + + parser.add_argument( + "--jit-joiner-filename", + required=True, + type=str, + help="Path to the torchscript joiner model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the ONNX encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the ONNX decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the ONNX joiner model", + ) + + return parser + + +def test_encoder( + torch_encoder_model: torch.jit.ScriptModule, + torch_encoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + N = torch.randint(1, 100, size=(1,)).item() + T = onnx_model.segment + C = 80 + x_lens = torch.tensor([T] * N) + torch_states = [torch_encoder_model.get_init_state() for _ in range(N)] + torch_states = stack_states(torch_states) + + onnx_model.init_encoder_states(N) + + for i in range(5): + logging.info(f"test_encoder: iter {i}") + x = torch.rand(N, T, C) + torch_encoder_out, _, torch_states = torch_encoder_model( + x, x_lens, torch_states + ) + torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) + + onnx_encoder_out = onnx_model.run_encoder(x) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_decoder_model: torch.jit.ScriptModule, + torch_decoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_joiner_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] + decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + 
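        # Note: the input projections (encoder_proj/decoder_proj) are folded into
        # the ONNX encoder and decoder, so the ONNX joiner below is fed
        # already-projected tensors, while the torchscript joiner projects its
        # raw inputs internally; both paths must agree to within atol=1e-4.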
encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out) + projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out) + + torch_joiner_out = torch_joiner_model(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_encoder_model = torch.jit.load(args.jit_encoder_filename) + torch_decoder_model = torch.jit.load(args.jit_decoder_filename) + torch_joiner_model = torch.jit.load(args.jit_joiner_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + # When exporting the model to onnx, we have already put the encoder_proj + # inside the encoder. + test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model) + + logging.info("Test decoder") + # When exporting the model to onnx, we have already put the decoder_proj + # inside the decoder. + test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_joiner_model, onnx_model) + + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20230207) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 100644 index 000000000..71a418742 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1,231 @@ +from typing import Optional, Tuple + +import torch + + +class OnnxStreamingEncoder(torch.nn.Module): + """This class warps the streaming Zipformer to reduce the number of + state tensors for onnx. + https://github.com/k2-fsa/icefall/pull/831 + """ + + def __init__(self, encoder): + """ + Args: + encoder: An instance of Zipformer Class + """ + super().__init__() + self.model = encoder + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + len_cache: torch.tensor, + avg_cache: torch.tensor, + attn_cache: torch.tensor, + cnn_cache: torch.tensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. 
+ len_cache: + The cached numbers of past frames. + avg_cache: + The cached average tensors. + attn_cache: + The cached key tensors of the first attention modules. + The cached value tensors of the first attention modules. + The cached value tensors of the second attention modules. + cnn_cache: + The cached left contexts of the first convolution modules. + The cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 2 tensors: + + """ + num_encoder_layers = [] + encoder_attention_dims = [] + states = [] + for i, encoder in enumerate(self.model.encoders): + num_encoder_layers.append(encoder.num_layers) + encoder_attention_dims.append(encoder.attention_dim) + + len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B] + offset = 0 + for num_layer in num_encoder_layers: + states.append(len_cache[offset : offset + num_layer]) + offset += num_layer + + avg_cache = avg_cache.transpose(0, 1) # [15, B, 384] + offset = 0 + for num_layer in num_encoder_layers: + states.append(avg_cache[offset : offset + num_layer]) + offset += num_layer + + attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192] + left_context_len = attn_cache.shape[1] + offset = 0 + for i, num_layer in enumerate(num_encoder_layers): + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[offset : offset + num_layer, : left_context_len // ds] + ) + offset += num_layer + for i, num_layer in enumerate(num_encoder_layers): + encoder_attention_dim = encoder_attention_dims[i] + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[ + offset : offset + num_layer, + : left_context_len // ds, + :, + : encoder_attention_dim // 2, + ] + ) + offset += num_layer + for i, num_layer in enumerate(num_encoder_layers): + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[ + offset : offset + num_layer, + : left_context_len // ds, + :, + : encoder_attention_dim // 2, + ] + ) + offset += num_layer + + cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1] + offset = 0 + for num_layer in num_encoder_layers: + states.append(cnn_cache[offset : offset + num_layer]) + offset += num_layer + for num_layer in num_encoder_layers: + states.append(cnn_cache[offset : offset + num_layer]) + offset += num_layer + + encoder_out, encoder_out_lens, new_states = self.model.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose( + 0, 1 + ) # [B,15] + new_avg_cache = torch.cat( + states[self.model.num_encoders : 2 * self.model.num_encoders] + ).transpose( + 0, 1 + ) # [B,15,384] + new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose( + 0, 1 + ) # [B,2*15,384,cnn_kernel-1] + assert len(set(encoder_attention_dims)) == 1 + pad_tensors = [ + torch.nn.functional.pad( + tensor, + ( + 0, + encoder_attention_dims[0] - tensor.shape[-1], + 0, + 0, + 0, + left_context_len - tensor.shape[1], + 0, + 0, + ), + ) + for tensor in states[ + 2 * self.model.num_encoders : 5 * self.model.num_encoders + ] + ] + new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] + + return ( + encoder_out, + encoder_out_lens, + new_len_cache, + new_avg_cache, + new_attn_cache, + new_cnn_cache, + ) + + +class TritonOnnxDecoder(torch.nn.Module): + """This class warps the Decoder in decoder.py + to remove the scalar input "need_pad". + Triton currently doesn't support scalar input. 
+ https://github.com/triton-inference-server/server/issues/2333 + """ + + def __init__( + self, + decoder: torch.nn.Module, + ): + """ + Args: + decoder: A instance of Decoder + """ + super().__init__() + self.model = decoder + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + # False to not pad the input. Should be False during inference. + need_pad = False + return self.model(y, need_pad) + + +class TritonOnnxJoiner(torch.nn.Module): + """This class warps the Joiner in joiner.py + to remove the scalar input "project_input". + Triton currently doesn't support scalar input. + https://github.com/triton-inference-server/server/issues/2333 + "project_input" is set to True. + Triton solutions only need export joiner to a single joiner.onnx. + """ + + def __init__( + self, + joiner: torch.nn.Module, + ): + super().__init__() + self.model = joiner + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + # Apply input projections encoder_proj and decoder_proj. + project_input = True + return self.model(encoder_out, decoder_out, project_input) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 100755 index 000000000..8192e01fd --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1,512 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. 
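A batched variant is a straightforward extension (an illustrative sketch, not implemented in this script): call model.init_encoder_states(N) once and pass feature chunks of shape (N, T, 80) to model.run_encoder(); onnx_check.py shows the encoder being run this way with N > 1.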
+""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + model_type = encoder_meta["model_type"] + assert model_type == "zipformer", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = encoder_meta["num_encoder_layers"] + encoder_dims = encoder_meta["encoder_dims"] + attention_dims = encoder_meta["attention_dims"] + cnn_module_kernels = encoder_meta["cnn_module_kernels"] + left_context_len = encoder_meta["left_context_len"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + attention_dims = to_int_list(attention_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"attention_dims: {attention_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + + num_encoders = len(num_encoder_layers) + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + N = batch_size + + for i in range(num_encoders): + cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64)) + cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i])) + cached_key.append( + torch.zeros( + num_encoder_layers[i], left_context_len[i], N, attention_dims[i] + ) + ) + cached_val.append( + torch.zeros( + num_encoder_layers[i], + 
left_context_len[i], + N, + attention_dims[i] // 2, + ) + ) + cached_val2.append( + torch.zeros( + num_encoder_layers[i], + left_context_len[i], + N, + attention_dims[i] // 2, + ) + ) + cached_conv1.append( + torch.zeros( + num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 + ) + ) + cached_conv2.append( + torch.zeros( + num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 + ) + ) + + self.cached_len = cached_len + self.cached_avg = cached_avg + self.cached_key = cached_key + self.cached_val = cached_val + self.cached_val2 = cached_val2 + self.cached_conv1 = cached_conv1 + self.cached_conv2 = cached_conv2 + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_states_input(states: List[torch.Tensor], name: str): + for i, s in enumerate(states): + if isinstance(s, torch.Tensor): + encoder_input[f"{name}_{i}"] = s.numpy() + else: + encoder_input[f"{name}_{i}"] = s + + encoder_output.append(f"new_{name}_{i}") + + build_states_input(self.cached_len, "cached_len") + build_states_input(self.cached_avg, "cached_avg") + build_states_input(self.cached_key, "cached_key") + build_states_input(self.cached_val, "cached_val") + build_states_input(self.cached_val2, "cached_val2") + build_states_input(self.cached_conv1, "cached_conv1") + build_states_input(self.cached_conv2, "cached_conv2") + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + num_encoders = self.num_encoders + + self.cached_len = states[num_encoders * 0 : num_encoders * 1] + self.cached_avg = states[num_encoders * 1 : num_encoders * 2] + self.cached_key = states[num_encoders * 2 : num_encoders * 3] + self.cached_val = states[num_encoders * 3 : num_encoders * 4] + self.cached_val2 = states[num_encoders * 4 : num_encoders * 5] + self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6] + self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7] + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + 
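Example (illustrative): with a context size of 2 and blank id 0,
          model.run_decoder(torch.tensor([[0, 0]], dtype=torch.int64)) produces the
          initial (1, joiner_dim) decoder embedding that greedy_search() below starts from.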
""" + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. 
+ """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 100755 index 000000000..fb77fdd42 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_streaming/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. 
" + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in 
zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 100755 index 000000000..8acace979 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1,442 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
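# Overview of this script (see main() below): the input wav is loaded in full,
# but decoding itself is streaming: fixed-size blocks of samples are pushed into
# a kaldifeat OnlineFbank extractor, the ncnn-exported zipformer encoder is run
# chunk by chunk with its cached states, and greedy search is applied with the
# exported decoder/joiner to produce the transcript.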
+""" +Usage: + +./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \ + --encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \ + ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav + +You can find pretrained models at +- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 +- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13 +""" + +import argparse +import logging +from typing import List, Optional, Tuple + +import k2 +import ncnn +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + # Please change the parameters according to your model + + # 20M + # self.num_encoder_layers = to_int_tuple("2,2,2,2,2") + # self.encoder_dims = to_int_tuple("256,256,256,256,256") # also known as d_model + # self.attention_dims = to_int_tuple("192,192,192,192,192") + # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + + # 9.6M + # self.num_encoder_layers = to_int_tuple("2,3,2,2,3") + # self.encoder_dims = to_int_tuple("160,160,160,160,160") # also known as d_model + # self.attention_dims = to_int_tuple("96,96,96,96,96") + # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + + # 5.5M or 6M + + # self.num_encoder_layers = to_int_tuple("2,2,2,2,2") + # self.encoder_dims = to_int_tuple("128,128,128,128,128") # also known as d_model + # self.attention_dims = to_int_tuple("96,96,96,96,96") + # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + + self.num_encoder_layers = 
to_int_tuple("2,4,3,2,4") + self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model + self.attention_dims = to_int_tuple("192,192,192,192,192") + self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + + self.decode_chunk_size = 32 // 2 + num_left_chunks = 4 + self.left_context_length = self.decode_chunk_size * num_left_chunks # 64 + + self.chunk_length = self.decode_chunk_size * 2 + pad_length = 7 + self.T = self.chunk_length + pad_length + + def get_init_states(self) -> List[torch.Tensor]: + cached_len_list = [] + cached_avg_list = [] + cached_key_list = [] + cached_val_list = [] + cached_val2_list = [] + cached_conv1_list = [] + cached_conv2_list = [] + + for i in range(len(self.num_encoder_layers)): + num_layers = self.num_encoder_layers[i] + ds = self.zipformer_downsampling_factors[i] + attention_dim = self.attention_dims[i] + left_context_length = self.left_context_length // ds + encoder_dim = self.encoder_dims[i] + cnn_module_kernel = self.cnn_module_kernels[i] + + cached_len_list.append(torch.zeros(num_layers)) + cached_avg_list.append(torch.zeros(num_layers, encoder_dim)) + cached_key_list.append( + torch.zeros(num_layers, left_context_length, attention_dim) + ) + cached_val_list.append( + torch.zeros(num_layers, left_context_length, attention_dim // 2) + ) + cached_val2_list.append( + torch.zeros(num_layers, left_context_length, attention_dim // 2) + ) + cached_conv1_list.append( + torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1) + ) + cached_conv2_list.append( + torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1) + ) + + states = ( + cached_len_list + + cached_avg_list + + cached_key_list + + cached_val_list + + cached_val2_list + + cached_conv1_list + + cached_conv2_list + ) + + return states + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_net.opt.num_threads = 4 + + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + decoder_net.opt.num_threads = 4 + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.opt.num_threads = 4 + + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + def run_encoder( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: + A tensor of shape (T, C) + states: + A list of tensors. len(states) == self.num_layers * 4 + Returns: + Return a tuple containing: + - encoder_out, a tensor of shape (T, encoder_dim). 
+ - next_states, a list of tensors containing the next states + """ + with self.encoder_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + + for i in range(len(states)): + name = f"in{i+1}" + ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + + out_states: List[torch.Tensor] = [] + for i in range(len(states)): + name = f"out{i+1}" + ret, ncnn_out_state = ex.extract(name) + assert ret == 0, ret + ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy()) + + if i < len(self.num_encoder_layers): + # for cached_len, we need to discard the last dim + ncnn_out_state = ncnn_out_state.squeeze(1) + + out_states.append(ncnn_out_state) + + return encoder_out, out_states + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + else: + assert decoder_out.ndim == 1 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t] + + joiner_out = model.run_joiner(cur_encoder_out, decoder_out) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + states = model.get_init_states() + logging.info(f"number of states: {len(states)}") + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = model.T + offset = model.chunk_length + + chunk = int(1 * sample_rate) # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, states = model.run_encoder(frames, states) + hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + context_size = 2 + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(sound_file) + logging.info(text) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..3a5f89833 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py 
b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..c272ed641 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
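        # Each DecodeStream holds the per-utterance features, encoder states and
        # partial decoding result; streams are accumulated here and, once
        # params.num_decode_streams of them are active, decode_one_chunk()
        # advances all of them in parallel (see the while-loops below).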
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
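        # write_error_stats() also returns the WER for this decoding setting;
        # the WERs of all settings are collected below, sorted (best first) and
        # written to wer-summary-*.txt as tab-separated "settings\tWER" rows.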
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not 
enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..de12d7af1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
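# This file builds the streaming zipformer transducer with its default
# configuration and sanity-checks the export paths: torch.jit.script() on the
# model after convert_scaled_to_non_scaled(), and torch.jit.trace() of the
# encoder/decoder/joiner, whose traced outputs are compared against the eager
# ones (see the tests below).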
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_small(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,2,2,2,2" + params.feedforward_dims = "256,256,512,512,256" + params.nhead = "4,4,4,4,4" + params.encoder_dims = "128,128,128,128,128" + params.attention_dims = "96,96,96,96,96" + params.encoder_unmasked_dims = "96,96,96,96,96" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 320 + params.joiner_dim = 320 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + import pdb + + pdb.set_trace() + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model_small() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..b2f9ffc09 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1264 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + 
help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. 
We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. 
It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + 
model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. 
See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. 
+ optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
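+            # Concretely: every 100 batches we read the current scale; if it has
+            # dropped below 1.0 (or is still below 8.0, re-checked on every 400th
+            # batch) it is doubled, a warning is logged once it falls below 0.01,
+            # and training aborts if it underflows below 1e-5, which typically
+            # means the gradients keep overflowing to inf/nan.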
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + # train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..5437e961e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1265 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
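+#
+# Note: the usage examples in the docstring below still invoke train.py; this
+# variant imports Zipformer from zipformer2 and builds the encoder with
+# is_pnnx=True in get_encoder_model, presumably so the resulting module can be
+# traced for PNNX/ncnn export.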
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
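+            # The scaler starts at init_scale=1.0 (see run() below) and is halved
+            # by GradScaler whenever a batch produces inf/nan gradients (that
+            # optimizer step is skipped), so a value below 1.0 means some recent
+            # steps were skipped; the block below nudges the scale back up once
+            # the gradients are finite again.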
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 100644 index 000000000..a5c422959 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1,2891 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. 
+) +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + Identity, + MaxEig, + ScaledConv1d, + Whiten, + _diag, + penalize_abs_values_gt, + random_clamp, + softmax, +) +from torch import Tensor, nn + +from icefall.utils import make_pad_mask, subsequent_chunk_mask + + +def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Note: + It is the inverse of :func:`unstack_states`. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + ``states[i][0:num_encoders]`` is the cached numbers of past frames. + ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A new state corresponding to a batch of utterances. + See the input argument of :func:`unstack_states` for the meaning + of the returned tensor. + """ + batch_size = len(state_list) + assert len(state_list[0]) % 7 == 0, len(state_list[0]) + num_encoders = len(state_list[0]) // 7 + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + # For cached_len + len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] + for i in range(num_encoders): + # len_avg: (num_layers, batch_size) + len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) + cached_len.append(len_avg) + + # For cached_avg + avg_list = [ + state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # avg: (num_layers, batch_size, D) + avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) + cached_avg.append(avg) + + # For cached_key + key_list = [ + state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # key: (num_layers, left_context_size, batch_size, D) + key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) + cached_key.append(key) + + # For cached_val + val_list = [ + state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val: (num_layers, left_context_size, batch_size, D) + val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) + cached_val.append(val) + + # For cached_val2 + val2_list = [ + state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val2: (num_layers, left_context_size, batch_size, D) + val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) + cached_val2.append(val2) + + # For cached_conv1 + conv1_list = [ + state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) + ] + for i in 
range(num_encoders): + # conv1: (num_layers, batch_size, D, kernel-1) + conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) + cached_conv1.append(conv1) + + # For cached_conv2 + conv2_list = [ + state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv2: (num_layers, batch_size, D, kernel-1) + conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A list of states. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + """ + assert len(states) % 7 == 0, len(states) + num_encoders = len(states) // 7 + ( + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) + + batch_size = cached_len[0].shape[1] + + len_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_len[i]: (num_layers, batch_size) + len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + len_list[n].append(len_avg[n]) + + avg_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_avg[i]: (num_layers, batch_size, D) + avg = cached_avg[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + avg_list[n].append(avg[n]) + + key_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_key[i]: (num_layers, left_context, batch_size, D) + key = cached_key[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + key_list[n].append(key[n]) + + val_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val[i]: (num_layers, left_context, batch_size, D) + val = cached_val[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val_list[n].append(val[n]) + + val2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val2[i]: (num_layers, left_context, batch_size, D) + val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val2_list[n].append(val2[n]) + + conv1_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) + conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + 
conv1_list[n].append(conv1[n]) + + conv2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) + conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + conv2_list[n].append(conv2[n]) + + state_list = [ + ( + len_list[i] + + avg_list[i] + + key_list[i] + + val_list[i] + + val2_list[i] + + conv1_list[i] + + conv2_list[i] + ) + for i in range(batch_size) + ] + return state_list + + +class Zipformer(EncoderInterface): + """ + Args: + num_features (int): Number of input features + d_model: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + nhead (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernels (int): Kernel size of convolution module + warmup_batches (float): number of batches to warm up over + """ + + def __init__( + self, + num_features: int, + output_downsampling_factor: int = 2, + encoder_dims: Tuple[int] = (384, 384), + attention_dim: Tuple[int] = (256, 256), + encoder_unmasked_dims: Tuple[int] = (256, 256), + zipformer_downsampling_factors: Tuple[int] = (2, 4), + nhead: Tuple[int] = (8, 8), + feedforward_dim: Tuple[int] = (1536, 2048), + num_encoder_layers: Tuple[int] = (12, 12), + dropout: float = 0.1, + cnn_module_kernels: Tuple[int] = (31, 31), + pos_dim: int = 4, + num_left_chunks: int = 4, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 50, + decode_chunk_size: int = 16, + warmup_batches: float = 4000.0, + ) -> None: + super(Zipformer, self).__init__() + + self.num_features = num_features + assert 0 < encoder_dims[0] <= encoder_dims[1] + self.encoder_dims = encoder_dims + self.encoder_unmasked_dims = encoder_unmasked_dims + self.zipformer_downsampling_factors = zipformer_downsampling_factors + self.output_downsampling_factor = output_downsampling_factor + + self.num_left_chunks = num_left_chunks + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + + # Used in decoding + self.decode_chunk_size = decode_chunk_size + + self.left_context_len = self.decode_chunk_size * self.num_left_chunks + + # will be written to, see set_batch_count() + self.batch_count = 0 + self.warmup_end = warmup_batches + + for u, d in zip(encoder_unmasked_dims, encoder_dims): + assert u <= d, (u, d) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7)//2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7)//2 + # (2) embedding: num_features -> encoder_dims + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout + ) + + # each one will be ZipformerEncoder or DownsampledZipformerEncoder + encoders = [] + + self.num_encoder_layers = num_encoder_layers + self.num_encoders = len(encoder_dims) + self.attention_dims = attention_dim + self.cnn_module_kernels = cnn_module_kernels + for i in range(self.num_encoders): + encoder_layer = ZipformerEncoderLayer( + encoder_dims[i], + attention_dim[i], + nhead[i], + feedforward_dim[i], + dropout, + cnn_module_kernels[i], + pos_dim, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. 
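+            # Each stack i gets its own warmup window: with the default
+            # warmup_batches=4000.0 and, e.g., 5 encoder stacks, stack 0 warms up
+            # over batches ~667-1333, stack 1 over ~1333-2000, and so on, per the
+            # warmup_begin/warmup_end expressions just below.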
+ encoder = ZipformerEncoder( + encoder_layer, + num_encoder_layers[i], + dropout, + warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), + ) + + if zipformer_downsampling_factors[i] != 1: + encoder = DownsampledZipformerEncoder( + encoder, + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + output_dim=encoder_dims[i], + downsample=zipformer_downsampling_factors[i], + ) + encoders.append(encoder) + self.encoders = nn.ModuleList(encoders) + + # initializes self.skip_layers and self.skip_modules + self._init_skip_modules() + + self.downsample_output = AttentionDownsample( + encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor + ) + + def _get_layer_skip_dropout_prob(self): + if not self.training: + return 0.0 + batch_count = self.batch_count + min_dropout_prob = 0.025 + + if batch_count > self.warmup_end: + return min_dropout_prob + else: + return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) + + def _init_skip_modules(self): + """ + If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of + layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, + we combine the outputs of layers 1 and 4. + """ + skip_layers = [] + skip_modules = [] + z = self.zipformer_downsampling_factors + for i in range(len(z)): + if i <= 1 or z[i - 1] <= z[i]: + skip_layers.append(None) + skip_modules.append(SimpleCombinerIdentity()) + else: + # TEMP + for j in range(i - 2, -1, -1): + if z[j] <= z[i] or j == 0: + # TEMP logging statement. + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." + ) + skip_layers.append(j) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) + break + self.skip_layers = skip_layers + self.skip_modules = nn.ModuleList(skip_modules) + + def get_feature_masks(self, x: torch.Tensor) -> List[float]: + # Note: The actual return type is Union[List[float], List[Tensor]], + # but to make torch.jit.script() work, we use List[float] + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all encoder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoder dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_downsampling_factors times. 
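+        For example, with zipformer_downsampling_factors=(2, 4), the coarsest mask is
+        drawn at the rate of the stack downsampled by 4, and the mask for the stack
+        downsampled by 2 repeats each of its values twice.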
+ + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (num_frames, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dims) + if torch.jit.is_scripting() or not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) + + max_downsampling_factor = max(self.zipformer_downsampling_factors) + + num_frames_max = num_frames0 + max_downsampling_factor - 1 + + feature_mask_dropout_prob = 0.15 + + # frame_mask_max shape: (num_frames_max, batch_size, 1) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) + + feature_masks = [] + for i in range(num_encoders): + ds = self.zipformer_downsampling_factors[i] + upsample_factor = max_downsampling_factor // ds + + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) + num_frames = (num_frames0 + ds - 1) // ds + frame_mask = frame_mask[:num_frames] + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) + u = self.encoder_unmasked_dims[i] + feature_mask[:, :, u:] *= frame_mask + feature_masks.append(feature_mask) + + return feature_masks + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + chunk_size: + The chunk size used in evaluation mode. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. 
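+
+        Example (a rough, shape-only sketch using the default configuration;
+        not taken from the recipe)::
+
+            >>> model = Zipformer(num_features=80)
+            >>> _ = model.eval()
+            >>> x = torch.rand(2, 200, 80)       # (batch, time, feature)
+            >>> x_lens = torch.full((2,), 200)
+            >>> y, y_lens = model(x, x_lens)     # y: (2, 48, 384), y_lens: [48, 48]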
+ """ + x = self.encoder_embed(x) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + mask = make_pad_mask(lengths) + + outputs = [] + feature_masks = self.get_feature_masks(x) + + if self.training: + # Training mode + max_ds = max(self.zipformer_downsampling_factors) + # Generate dynamic chunk-wise attention mask during training + max_len = x.size(0) // max_ds + short_chunk_size = self.short_chunk_size // max_ds + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + # Full attention + chunk_size = x.size(0) + else: + # Chunk-wise attention + chunk_size = chunk_size % short_chunk_size + 1 + chunk_size *= max_ds + else: + chunk_size = self.decode_chunk_size + # Evaluation mode + for ds in self.zipformer_downsampling_factors: + assert chunk_size % ds == 0, (chunk_size, ds) + + attn_mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, + ) + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + ds = self.zipformer_downsampling_factors[i] + k = self.skip_layers[i] + if isinstance(k, int): + layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() + if torch.jit.is_scripting(): + x = skip_module(outputs[k], x) + elif (not self.training) or random.random() > layer_skip_dropout_prob: + x = skip_module(outputs[k], x) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + attn_mask=attn_mask[::ds, ::ds], + ) + outputs.append(x) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: List[Tensor], + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + seq_len is the input chunk length. + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 3 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states. 
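+
+        Example (a rough sketch; shapes assume the default configuration with
+        decode_chunk_size=16, i.e. 16 subsampled frames per chunk)::
+
+            >>> model = Zipformer(num_features=80, decode_chunk_size=16)
+            >>> _ = model.eval()
+            >>> states = model.get_init_state()  # batch size 1
+            >>> chunk = torch.rand(1, 2 * 16 + 7, 80)
+            >>> chunk_lens = torch.full((1,), 2 * 16 + 7)
+            >>> y, y_lens, states = model.streaming_forward(chunk, chunk_lens, states)
+            >>> # y: (1, 8, 384); feed the returned states back in with the next chunk.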
+ """ + assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) + + cached_len = states[: self.num_encoders] + cached_avg = states[self.num_encoders : 2 * self.num_encoders] + cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] + cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] + cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] + cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] + cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] + + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + + outputs = [] + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + k = self.skip_layers[i] + if isinstance(k, int): + x = skip_module(outputs[k], x) + x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( + x, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + outputs.append(x) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = ( + new_cached_len + + new_cached_avg + + new_cached_key + + new_cached_val + + new_cached_val2 + + new_cached_conv1 + + new_cached_conv2 + ) + return x, lengths, new_states + + @torch.jit.export + def get_init_state( + self, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. 
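+
+        For example, with the default 2 encoder stacks this returns a list of 14
+        tensors for a single utterance (batch size 1); e.g. states[0] has shape
+        (num_encoder_layers[0], 1) and states[2] has shape
+        (num_encoder_layers[0], left_context_len // 2, 1, attention_dim[0]).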
+ """ + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + left_context_len = self.decode_chunk_size * self.num_left_chunks + + for i, encoder in enumerate(self.encoders): + num_layers = encoder.num_layers + ds = self.zipformer_downsampling_factors[i] + + len_avg = torch.zeros(num_layers, 1, dtype=torch.int64, device=device) + cached_len.append(len_avg) + + avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) + cached_avg.append(avg) + + key = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim, + device=device, + ) + cached_key.append(key) + + val = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val.append(val) + + val2 = torch.zeros( + num_layers, + left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val2.append(val2) + + conv1 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv1.append(conv1) + + conv2 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +class ZipformerEncoderLayer(nn.Module): + """ + ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, + ) -> None: + super(ZipformerEncoderLayer, self).__init__() + + self.d_model = d_model + self.attention_dim = attention_dim + self.cnn_module_kernel = cnn_module_kernel + + # will be written to, see set_batch_count() + self.batch_count = 0 + + self.self_attn = RelPositionMultiheadAttention( + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, + ) + + self.pooling = PoolingModule(d_model) + + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_final = BasicNorm(d_model) + + self.bypass_scale = nn.Parameter(torch.tensor(0.5)) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
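+        # (ActivationBalancer acts only on gradients: it nudges the per-channel
+        # fraction of positive values towards [min_positive, max_positive] and
+        # discourages activations with magnitude above max_abs.)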
+ self.balancer = ActivationBalancer( + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + max_abs=6.0, + ) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) + + def get_bypass_scale(self): + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + if random.random() < 0.1: + # ensure we get grads if self.bypass_scale becomes out of range + return self.bypass_scale + # hardcode warmup period for bypass scale + warmup_period = 20000.0 + initial_clamp_min = 0.75 + final_clamp_min = 0.25 + if self.batch_count > warmup_period: + clamp_min = final_clamp_min + else: + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) + return self.bypass_scale.clamp(min=clamp_min, max=1.0) + + def get_dynamic_dropout_rate(self): + # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this + # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable + # at the beginning, by making the network focus on the feedforward modules. + if torch.jit.is_scripting() or not self.training: + return 0.0 + warmup_period = 2000.0 + initial_dropout_rate = 0.2 + final_dropout_rate = 0.0 + if self.batch_count > warmup_period: + return final_dropout_rate + else: + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + batch_split: if not None, this layer will only be applied to + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + # dropout rate for submodules that interact with time. 
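+        # (See get_dynamic_dropout_rate(): roughly 0.2 early in training, 0.0 once
+        # batch_count exceeds its warmup period, and always 0.0 at inference.)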
+ dynamic_dropout = self.get_dynamic_dropout_rate() + + # pooling module + if torch.jit.is_scripting(): + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + elif random.random() >= dynamic_dropout: + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + + if torch.jit.is_scripting(): + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + src = src + self.self_attn.forward2(src, attn_weights) + + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + else: + use_self_attn = random.random() >= dynamic_dropout + if use_self_attn: + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + if random.random() >= dynamic_dropout: + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + if use_self_attn: + src = src + self.self_attn.forward2(src, attn_weights) + + if random.random() >= dynamic_dropout: + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.get_bypass_scale() + + return self.whiten(src) + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + cached_len: processed number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor of left context for the first attention module. + cached_val: cached value tensor of left context for the first attention module. + cached_val2: cached value tensor of left context for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + pos_emb: (N, left_context_len+2*S-1, E) + cached_len: (N,) + N is the batch size. + cached_avg: (N, C). + N is the batch size, C is the feature dimension. + cached_key: (left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. 
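+
+        Returns:
+            A tuple of 8 tensors: the layer output (same shape as src), followed by
+            the updated cached_len, cached_avg, cached_key, cached_val, cached_val2,
+            cached_conv1 and cached_conv2, in that order.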
+ """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + src_pool, cached_len, cached_avg = self.pooling.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + ) + src = src + src_pool + + ( + src_attn, + attn_weights, + cached_key, + cached_val, + ) = self.self_attn.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + cached_val=cached_val, + ) + src = src + src_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + src_attn, cached_val2 = self.self_attn.streaming_forward2( + src, + attn_weights, + cached_val=cached_val2, + ) + src = src + src_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.bypass_scale + + return ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class ZipformerEncoder(nn.Module): + r"""ZipformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZipformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + ) -> None: + super().__init__() + # will be written to, see set_batch_count() Note: in inference time this + # may be zero but should be treated as large, we can check if + # self.training is true. + self.batch_count = 0 + self.warmup_begin = warmup_begin + self.warmup_end = warmup_end + # module_seed is for when we need a random number that is unique to the module but + # shared across jobs. It's used to randomly select how many layers to drop, + # so that we can keep this consistent across worker tasks (for efficiency). + self.module_seed = torch.randint(0, 1000, ()).item() + + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.d_model = encoder_layer.d_model + self.attention_dim = encoder_layer.attention_dim + self.cnn_module_kernel = encoder_layer.cnn_module_kernel + + assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin + for i in range(num_layers): + self.layers[i].warmup_begin = cur_begin + cur_begin += delta + self.layers[i].warmup_end = cur_begin + + def get_layers_to_drop(self, rnd_seed: int): + ans = set() + if not self.training: + return ans + + batch_count = self.batch_count + num_layers = len(self.layers) + + def get_layerdrop_prob(layer: int) -> float: + layer_warmup_begin = self.layers[layer].warmup_begin + layer_warmup_end = self.layers[layer].warmup_end + + initial_layerdrop_prob = 0.5 + final_layerdrop_prob = 0.05 + + if batch_count == 0: + # As a special case, if batch_count == 0, return 0 (drop no + # layers). 
This is rather ugly, I'm afraid; it is intended to + # enable our scan_pessimistic_batches_for_oom() code to work correctly + # so if we are going to get OOM it will happen early. + # also search for 'batch_count' with quotes in this file to see + # how we initialize the warmup count to a random number between + # 0 and 10. + return 0.0 + elif batch_count < layer_warmup_begin: + return initial_layerdrop_prob + elif batch_count > layer_warmup_end: + return final_layerdrop_prob + else: + # linearly interpolate + t = (batch_count - layer_warmup_begin) / layer_warmup_end + assert 0.0 <= t < 1.001, t + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) + + shared_rng = random.Random(batch_count + self.module_seed) + independent_rng = random.Random(rnd_seed) + + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + tot = sum(layerdrop_probs) + # Instead of drawing the samples independently, we first randomly decide + # how many layers to drop out, using the same random number generator between + # jobs so that all jobs drop out the same number (this is for speed). + # Then we use an approximate approach to drop out the individual layers + # with their specified probs while reaching this exact target. + num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) + + layers = list(range(num_layers)) + independent_rng.shuffle(layers) + + # go through the shuffled layers until we get the required number of samples. + if num_to_drop > 0: + for layer in itertools.cycle(layers): + if independent_rng.random() < layerdrop_probs[layer]: + ans.add(layer) + if len(ans) == num_to_drop: + break + if shared_rng.random() < 0.005 or __name__ == "__main__": + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) + return ans + + def forward( + self, + src: Tensor, + # Note: The type of feature_mask should be Union[float, Tensor], + # but to make torch.jit.script() work, we use `float` here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). 
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: (x, x_no_combine), both of shape (S, N, E) + """ + pos_emb = self.encoder_pos(src) + output = src + + if torch.jit.is_scripting(): + layers_to_drop = [] + else: + rnd_seed = src.numel() + random.randint(0, 1000) + layers_to_drop = self.get_layers_to_drop(rnd_seed) + + output = output * feature_mask + + for i, mod in enumerate(self.layers): + if not torch.jit.is_scripting(): + if i in layers_to_drop: + continue + output = mod( + output, + pos_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + output = output * feature_mask + + return output + + @torch.jit.export + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + cached_len: number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor for first attention module. + cached_val: cached value tensor for first attention module. + cached_val2: cached value tensor for second attention module. + cached_conv1: cached left contexts for the first convolution module. + cached_conv2: cached left contexts for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (num_layers,) + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + + Returns: A tuple of 8 tensors: + - output tensor + - updated cached number of past frames. + - updated cached average of past frames. + - updated cached key tensor of of the first attention module. + - updated cached value tensor of of the first attention module. + - updated cached value tensor of of the second attention module. + - updated cached left contexts of the first convolution module. + - updated cached left contexts of the second convolution module. 
+ """ + assert cached_len.size(0) == self.num_layers, ( + cached_len.size(0), + self.num_layers, + ) + assert cached_avg.size(0) == self.num_layers, ( + cached_avg.size(0), + self.num_layers, + ) + assert cached_key.size(0) == self.num_layers, ( + cached_key.size(0), + self.num_layers, + ) + assert cached_val.size(0) == self.num_layers, ( + cached_val.size(0), + self.num_layers, + ) + assert cached_val2.size(0) == self.num_layers, ( + cached_val2.size(0), + self.num_layers, + ) + assert cached_conv1.size(0) == self.num_layers, ( + cached_conv1.size(0), + self.num_layers, + ) + assert cached_conv2.size(0) == self.num_layers, ( + cached_conv2.size(0), + self.num_layers, + ) + + left_context_len = cached_key.shape[1] + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + for i, mod in enumerate(self.layers): + output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( + output, + pos_emb, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + return ( + output, + torch.stack(new_cached_len, dim=0), + torch.stack(new_cached_avg, dim=0), + torch.stack(new_cached_key, dim=0), + torch.stack(new_cached_val, dim=0), + torch.stack(new_cached_val2, dim=0), + torch.stack(new_cached_conv1, dim=0), + torch.stack(new_cached_conv2, dim=0), + ) + + +class DownsampledZipformerEncoder(nn.Module): + r""" + DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int + ): + super(DownsampledZipformerEncoder, self).__init__() + self.downsample_factor = downsample + self.downsample = AttentionDownsample(input_dim, output_dim, downsample) + self.encoder = encoder + self.num_layers = encoder.num_layers + self.d_model = encoder.d_model + self.attention_dim = encoder.attention_dim + self.cnn_module_kernel = encoder.cnn_module_kernel + self.upsample = SimpleUpsample(output_dim, downsample) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) + + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. feature_mask is expected to be already downsampled by + self.downsample_factor. + attn_mask: attention mask (optional). Should be downsampled already. + src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. + + Shape: + src: (S, N, E). 
+ attn_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + src = self.encoder( + src, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + cached_avg: cached average value of past frames. + cached_len: length of past frames. + cached_key: cached key tensor for the first attention module. + cached_val: cached value tensor for the first attention module. + cached_val2: cached value tensor for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (N,) + N is the batch size. + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = self.encoder.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + cached_key=cached_key, + cached_val=cached_val, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return ( + self.out_combiner(src_orig, src), + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class AttentionDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. 
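+
+    Example (a rough sketch)::
+
+        >>> m = AttentionDownsample(in_channels=384, out_channels=384, downsample=2)
+        >>> m(torch.rand(100, 4, 384)).shape
+        torch.Size([50, 4, 384])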
+ """ + + def __init__(self, in_channels: int, out_channels: int, downsample: int): + super(AttentionDownsample, self).__init__() + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + + # fill in the extra dimensions with a projection of the input + if out_channels > in_channels: + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) + else: + self.extra_proj = None + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, 1, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, out_channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + if seq_len != d_seq_len * ds: + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) + + if self.extra_proj is not None: + ans2 = self.extra_proj(src) + ans = torch.cat((ans, ans2), dim=2) + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.bias.shape[0] + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src + self.bias.unsqueeze(1) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class SimpleCombinerIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + return src1 + + +class SimpleCombiner(torch.nn.Module): + """ + A very simple way of combining 2 vectors of 2 different dims, via a + learned weighted combination in the shared part of the dim. + Args: + dim1: the dimension of the first input, e.g. 256 + dim2: the dimension of the second input, e.g. 384. + The output will have the same dimension as dim2. 
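+
+    Example (a rough sketch)::
+
+        >>> m = SimpleCombiner(256, 384, min_weight=(0.0, 0.25))
+        >>> m(torch.rand(10, 4, 256), torch.rand(10, 4, 384)).shape
+        torch.Size([10, 4, 384])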
+    """
+
+    def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)):
+        super(SimpleCombiner, self).__init__()
+        assert dim2 >= dim1, (dim2, dim1)
+        self.weight1 = nn.Parameter(torch.zeros(()))
+        self.min_weight = min_weight
+
+    def forward(self, src1: Tensor, src2: Tensor) -> Tensor:
+        """
+        src1: (*, dim1)
+        src2: (*, dim2)
+
+        Returns: a tensor of shape (*, dim2)
+        """
+        assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape)
+
+        weight1 = self.weight1
+        if not torch.jit.is_scripting():
+            if (
+                self.training
+                and random.random() < 0.25
+                and self.min_weight != (0.0, 0.0)
+            ):
+                weight1 = weight1.clamp(
+                    min=self.min_weight[0], max=1.0 - self.min_weight[1]
+                )
+
+        src1 = src1 * weight1
+        src2 = src2 * (1.0 - weight1)
+
+        src1_dim = src1.shape[-1]
+        src2_dim = src2.shape[-1]
+        if src1_dim != src2_dim:
+            if src1_dim < src2_dim:
+                src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim))
+            else:
+                src1 = src1[:src2_dim]
+
+        return src1 + src2
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        dropout_rate: float,
+        max_len: int = 5000,
+    ) -> None:
+        """Construct a PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+        """Reset the positional encodings."""
+        x_size_left = x.size(0) + left_context_len
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x_size_left * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means to the position of query vector and `j` means the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x_size_left, self.d_model)
+        pe_negative = torch.zeros(x_size_left, self.d_model)
+        position = torch.arange(0, x_size_left, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (time, batch, `*`).
+            left_context_len (int): Length of cached left context.
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`).
+
+        """
+        self.extend_pe(x, left_context_len)
+        x_size_left = x.size(0) + left_context_len
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x_size_left
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(0),
+        ]
+        return self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding.
+
+    This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
+    we have to write up the differences.
+
+    Args:
+        embed_dim: total dimension of the model.
+        attention_dim: dimension in the attention module, may be less or more than embed_dim
+            but must be a multiple of num_heads.
+        num_heads: parallel attention heads.
+        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
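+        pos_dim: per-head dimension of the projected positional-encoding query
+            (a small value, e.g. 4).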
+ + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + pos_dim: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = attention_dim // num_heads + self.pos_dim = pos_dim + assert self.head_dim % 2 == 0, self.head_dim + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5, dividing it between the query and key. + in_proj_dim = ( + 2 * attention_dim # query, key + + attention_dim // 2 # value + + pos_dim * num_heads # positional encoding query + ) + + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) + + # self.whiten_values is applied on the values in forward(); + # it just copies the keys but prevents low-rank distribution by modifying grads. + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option. + # they only copy their inputs. + self.copy_pos_query = Identity() + self.copy_query = Identity() + + self.out_proj = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + + self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + # self.whiten_values2 is applied on the values in forward2() + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Returns: (attn_output, attn_weights) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + """ + x, weights = self.multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + return x, weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. + - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. + + - Returns: (attn_output, attn_weights, cached_key, cached_val) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of + left context + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of + """ + ( + x, + weights, + cached_key, + cached_val, + ) = self.streaming_multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.out_proj.weight, + self.out_proj.bias, + cached_key=cached_key, + cached_val=cached_val, + ) + return x, weights, cached_key, cached_val + + def multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. 
+ pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[..., 2 * attention_dim + value_dim :] + + k = self.whiten_keys(k) # does nothing in the forward pass. + v = self.whiten_values(v) # does nothing in the forward pass. + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. 
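+        # (Concretely, with attention_dim=256, num_heads=8, pos_dim=4: q and k each
+        # take 256 features of x_proj, v takes 128 and p takes 32, so x_proj has
+        # 2*256 + 128 + 32 = 672 features in total.)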
+ + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, seq_len, seq_len]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + seq_len, + seq_len, + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(seq_len, bsz, num_heads, head_dim) + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == seq_len, "{} == {}".format( + key_padding_mask.size(1), seq_len + ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + if not torch.jit.is_scripting(): + if training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be too large. 
+ # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask, float("-inf") + ) + else: + attn_output_weights = attn_output_weights + attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + # If we are using chunk-wise attention mask and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`. At this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax). So we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + return attn_output, attn_output_weights + + def streaming_multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. 
+ out_proj_weight, out_proj_bias: the output projection weight and bias. + cached_key: cached attention key tensor of left context. + cached_val: cached attention value tensor of left context. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[..., 2 * attention_dim + value_dim :] + + left_context_len = cached_key.shape[0] + assert left_context_len > 0, left_context_len + assert cached_key.shape[0] == cached_val.shape[0], ( + cached_key.shape, + cached_val.shape, + ) + # Pad cached left contexts + k = torch.cat([cached_key, k], dim=0) + v = torch.cat([cached_val, v], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + cached_val = v[-left_context_len:, ...] + + # The length of key and value + kv_len = k.shape[0] + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(kv_len, bsz, num_heads, head_dim) + v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + left_context_len + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. 
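+        # As a small illustration of this re-indexing (hypothetical numbers:
+        # bsz=1, num_heads=1, seq_len=2, left_context_len=1, so kv_len=3 and
+        # seq_len2=4): if the last axis of pos_weights for one (batch, head,
+        # time1) position is [a, b, c, d], then output row i is the length-kv_len
+        # band starting at offset (seq_len - 1 - i), i.e.
+        #   row 0 -> [b, c, d]
+        #   row 1 -> [a, b, c].
+        # The torch.jit.is_tracing() branch below computes the same thing with
+        # an explicit torch.gather over those indexes.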
+ if torch.jit.is_tracing(): + (batch_size, num_heads, time1, n) = pos_weights.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(kv_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_weights = pos_weights.reshape(-1, n) + pos_weights = torch.gather(pos_weights, dim=1, index=indexes) + pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len) + else: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + return attn_output, attn_output_weights, cached_key, cached_val + + def forward2( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + Returns: + output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + v = self.whiten_values2(v) # does nothing in the forward pass. + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not torch.jit.is_scripting(): + if random.random() < 0.001 or __name__ == "__main__": + self._print_attn_stats(attn_weights, attn_output) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output) + + def streaming_forward2( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + cached_val: cached attention value tensor of left context. + Returns: + - output of the same shape as x, i.e. 
(seq_len, batch_size, embed_dim) + - updated cached attention value tensor of left context. + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + + left_context_len = cached_val.shape[0] + assert left_context_len > 0, left_context_len + v = torch.cat([cached_val, v], dim=0) + cached_val = v[-left_context_len:] + + seq_len2 = left_context_len + seq_len + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output), cached_val + + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + # attn_weights: (batch_size * num_heads, seq_len, seq_len) + # attn_output: (bsz * num_heads, seq_len, head_dim) + (n, seq_len, head_dim) = attn_output.shape + num_heads = self.num_heads + bsz = n // num_heads + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_output = attn_output.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) + attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) + attn_output_mean = attn_output.mean(dim=1, keepdim=True) + attn_output = attn_output - attn_output_mean + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) + # attn_covar: (num_heads, head_dim, head_dim) + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + + attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) + embed_dim = self.in_proj2.weight.shape[1] + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" + ) + + +class PoolingModule(nn.Module): + """ + Averages the input over the time dimension and project with a square matrix. + """ + + def __init__(self, d_model: int): + super().__init__() + self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: a Tensor of shape (T, N, C) + src_key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked + positions. + + Returns: + - output, a Tensor of shape (T, N, C). 
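+
+        Note: when src_key_padding_mask is None, output frame t is the running
+        mean of input frames 0..t, i.e. x.cumsum(dim=0)[t] / (t + 1), followed
+        by the projection self.proj; when a mask is given, the denominator
+        counts only unmasked frames and masked positions give zero output.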
+ """ + if src_key_padding_mask is not None: + # False in padding positions + padding_mask = src_key_padding_mask.logical_not().to(x.dtype) # (N, T) + # Cumulated numbers of frames from start + cum_mask = padding_mask.cumsum(dim=1) # (N, T) + x = x.cumsum(dim=0) # (T, N, C) + pooling_mask = padding_mask / cum_mask + pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + else: + num_frames = x.shape[0] + cum_mask = torch.arange(1, num_frames + 1).unsqueeze(1) # (T, 1) + x = x.cumsum(dim=0) # (T, N, C) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask + + x = self.proj(x) + return x + + def streaming_forward( + self, + x: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Args: + x: a Tensor of shape (T, N, C) + cached_len: a Tensor of int, of shape (N,), containing the number of + past frames in batch. + cached_avg: a Tensor of shape (N, C), the average over all past frames + in batch. + + Returns: + A tuple of 2 tensors: + - output, a Tensor of shape (T, N, C). + - updated cached_avg, a Tensor of shape (N, C). + """ + x = x.cumsum(dim=0) # (T, N, C) + x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) + # Cumulated numbers of frames from start + cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) + cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + + cached_len = cached_len + x.size(0) + cached_avg = x[-1] + + x = self.proj(x) + return x, cached_len, cached_avg + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(d_model, feedforward_dim) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) + self.activation = DoubleSwish() + self.dropout = nn.Dropout(dropout) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.balancer(x) + x = self.activation(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0, kernel_size + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. 
This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + # Will pad cached left context + self.lorder = kernel_size - 1 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=0, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + max_abs=20.0, + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains bool in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + # 1D Depthwise Conv + # Make depthwise_conv causal by + # manualy padding self.lorder zeros to the left + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch: + (batch, #time), contains bool in masked positions. + cache: Cached left context for depthwise_conv, with shape of + (batch, channels, #kernel_size-1). Only used in real streaming decoding. + + Returns: + A tuple of 2 tensors: + - Output tensor (#time, batch, channels). + - New cached left context, with shape of (batch, channels, #kernel_size-1). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). 
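+        # Sketch of the cache handling below: instead of left-padding with
+        # self.lorder (= kernel_size - 1) zeros as in forward(), we prepend the
+        # cached last (kernel_size - 1) frames of the previous chunk, so the
+        # depthwise conv sees the same left context a non-streaming pass would
+        # have seen, and the last (kernel_size - 1) frames of the concatenated
+        # input become the new cache.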
+ + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( + cache.shape, + (x.size(0), x.size(1), self.lorder), + ) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[:, :, -self.lorder :] + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1), cache + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: float = 0.1, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-7)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer2_channels: + Number of channels in layer2 + layer3_channels: + Number of channels in layer3 + """ + assert in_channels >= 7, in_channels + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + ActivationBalancer(layer1_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + ActivationBalancer(layer2_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + ActivationBalancer(layer3_channels, channel_dim=1), + DoubleSwish(), + ) + out_height = (((in_channels - 1) // 2) - 1) // 2 + self.out = ScaledLinear(out_height * layer3_channels, out_channels) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, (T-7)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) + # Now x is of shape (N, (T-7)//2, odim) + x = self.dropout(x) + return x + + +def _test_zipformer_main(): + feature_dim = 50 + batch_size = 5 + seq_len = 47 + feature_dim = 50 + # Just make sure the forward pass runs. + + c = Zipformer( + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + decode_chunk_size=4, + ) + # Just make sure the forward pass runs. 
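+    # Worked example of the length bookkeeping checked further below: with
+    # seq_len = 47, Conv2dSubsampling maps T -> (T - 7) // 2 = 20 frames, and
+    # the final output downsampling (factor 2, rounding up) maps 20 -> 10.
+    # `expected_out_len` is just a local sanity-check helper.
+    expected_out_len = ((seq_len - 7) // 2 + 1) // 2
+    assert expected_out_len == 10, expected_out_len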
+ f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +def _test_conv2d_subsampling(): + num_features = 80 + encoder_dims = 384 + dropout = 0.1 + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + for i in range(20, 40): + x = torch.rand(2, i, num_features) + y = encoder_embed(x) + assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + + +def _test_pooling_module(): + N, S, C = 2, 12, 32 + chunk_len = 4 + m = PoolingModule(d_model=C) + + # test chunk-wise forward with padding_mask + x = torch.randn(S, N, C) + y = m(x) + cached_len = torch.zeros(N, dtype=torch.int32) + cached_avg = torch.zeros(N, C) + for i in range(S // chunk_len): + start = i * chunk_len + end = start + chunk_len + x_chunk = x[start:end] + y_chunk, cached_len, cached_avg = m.streaming_forward( + x_chunk, + cached_len=cached_len, + cached_avg=cached_avg, + ) + assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) + + +def _test_state_stack_unstack(): + m = Zipformer( + num_features=80, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + zipformer_downsampling_factors=(4, 8), + num_left_chunks=2, + decode_chunk_size=8, + ) + s1 = m.get_init_state() + s2 = m.get_init_state() + states = stack_states([s1, s2]) + new_s1, new_s2 = unstack_states(states) + for i in range(m.num_encoders * 7): + for x, y in zip(s1[i], new_s1[i]): + assert torch.equal(x, y) + for x, y in zip(s2[i], new_s2[i]): + assert torch.equal(x, y) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main() + _test_conv2d_subsampling() + _test_pooling_module() + _test_state_stack_unstack() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 100644 index 000000000..be9cd1608 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1,3144 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( # not as in other dirs.. just scales down initial parameter values. 
+ ActivationBalancer, + BasicNorm, + DoubleSwish, + Identity, + MaxEig, + ScaledConv1d, + ScaledLinear, + Whiten, + _diag, + penalize_abs_values_gt, + random_clamp, + softmax, +) +from torch import Tensor, nn +from zipformer import PoolingModule + +from icefall.utils import make_pad_mask, subsequent_chunk_mask + + +def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Note: + It is the inverse of :func:`unstack_states`. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + ``states[i][0:num_encoders]`` is the cached numbers of past frames. + ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A new state corresponding to a batch of utterances. + See the input argument of :func:`unstack_states` for the meaning + of the returned tensor. + """ + batch_size = len(state_list) + assert len(state_list[0]) % 7 == 0, len(state_list[0]) + num_encoders = len(state_list[0]) // 7 + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + # For cached_len + len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] + for i in range(num_encoders): + # len_avg: (num_layers, batch_size) + len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) + cached_len.append(len_avg) + + # For cached_avg + avg_list = [ + state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # avg: (num_layers, batch_size, D) + avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) + cached_avg.append(avg) + + # For cached_key + key_list = [ + state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # key: (num_layers, left_context_size, batch_size, D) + key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) + cached_key.append(key) + + # For cached_val + val_list = [ + state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val: (num_layers, left_context_size, batch_size, D) + val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) + cached_val.append(val) + + # For cached_val2 + val2_list = [ + state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val2: (num_layers, left_context_size, batch_size, D) + val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) + cached_val2.append(val2) + + # For cached_conv1 + conv1_list = [ + state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) + 
] + for i in range(num_encoders): + # conv1: (num_layers, batch_size, D, kernel-1) + conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) + cached_conv1.append(conv1) + + # For cached_conv2 + conv2_list = [ + state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv2: (num_layers, batch_size, D, kernel-1) + conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A list of states. + ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. + """ + assert len(states) % 7 == 0, len(states) + num_encoders = len(states) // 7 + ( + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) + + batch_size = cached_len[0].shape[1] + + len_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_len[i]: (num_layers, batch_size) + len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + len_list[n].append(len_avg[n]) + + avg_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_avg[i]: (num_layers, batch_size, D) + avg = cached_avg[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + avg_list[n].append(avg[n]) + + key_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_key[i]: (num_layers, left_context, batch_size, D) + key = cached_key[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + key_list[n].append(key[n]) + + val_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val[i]: (num_layers, left_context, batch_size, D) + val = cached_val[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val_list[n].append(val[n]) + + val2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_val2[i]: (num_layers, left_context, batch_size, D) + val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) + for n in range(batch_size): + val2_list[n].append(val2[n]) + + conv1_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) + conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + 
conv1_list[n].append(conv1[n]) + + conv2_list = [[] for _ in range(batch_size)] + for i in range(num_encoders): + # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) + conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) + for n in range(batch_size): + conv2_list[n].append(conv2[n]) + + state_list = [ + ( + len_list[i] + + avg_list[i] + + key_list[i] + + val_list[i] + + val2_list[i] + + conv1_list[i] + + conv2_list[i] + ) + for i in range(batch_size) + ] + return state_list + + +class Zipformer(EncoderInterface): + """ + Args: + num_features (int): Number of input features + d_model: (int,int): embedding dimension of 2 encoder stacks + attention_dim: (int,int): attention dimension of 2 encoder stacks + nhead (int, int): number of heads + dim_feedforward (int, int): feedforward dimension in 2 encoder stacks + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernels (int): Kernel size of convolution module + warmup_batches (float): number of batches to warm up over + is_pnnx (bool): True if we are going to convert this model via pnnx. + """ + + def __init__( + self, + num_features: int, + output_downsampling_factor: int = 2, + encoder_dims: Tuple[int] = (384, 384), + attention_dim: Tuple[int] = (256, 256), + encoder_unmasked_dims: Tuple[int] = (256, 256), + zipformer_downsampling_factors: Tuple[int] = (2, 4), + nhead: Tuple[int] = (8, 8), + feedforward_dim: Tuple[int] = (1536, 2048), + num_encoder_layers: Tuple[int] = (12, 12), + dropout: float = 0.1, + cnn_module_kernels: Tuple[int] = (31, 31), + pos_dim: int = 4, + num_left_chunks: int = 4, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 50, + decode_chunk_size: int = 16, + warmup_batches: float = 4000.0, + is_pnnx: bool = False, + ) -> None: + super(Zipformer, self).__init__() + self.is_pnnx = is_pnnx + + self.num_features = num_features + assert 0 < encoder_dims[0] <= encoder_dims[1] + self.encoder_dims = encoder_dims + self.encoder_unmasked_dims = encoder_unmasked_dims + self.zipformer_downsampling_factors = zipformer_downsampling_factors + self.output_downsampling_factor = output_downsampling_factor + + self.num_left_chunks = num_left_chunks + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + + # Used in decoding + self.decode_chunk_size = decode_chunk_size + + self.left_context_len = self.decode_chunk_size * self.num_left_chunks + + # will be written to, see set_batch_count() + self.batch_count = 0 + self.warmup_end = warmup_batches + + for u, d in zip(encoder_unmasked_dims, encoder_dims): + assert u <= d, (u, d) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7)//2, encoder_dims). 
+ # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7)//2 + # (2) embedding: num_features -> encoder_dims + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout, is_pnnx=is_pnnx + ) + + # each one will be ZipformerEncoder or DownsampledZipformerEncoder + encoders = [] + + self.num_encoders = len(encoder_dims) + for i in range(self.num_encoders): + ds = zipformer_downsampling_factors[i] + encoder_layer = ZipformerEncoderLayer( + encoder_dims[i], + attention_dim[i], + nhead[i], + feedforward_dim[i], + dropout, + cnn_module_kernels[i], + pos_dim, + is_pnnx=self.is_pnnx, + left_context_len=self.left_context_len // ds, + x_size=self.decode_chunk_size // ds, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = ZipformerEncoder( + encoder_layer, + num_encoder_layers[i], + dropout, + warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), + is_pnnx=is_pnnx, + left_context_len=self.left_context_len // ds, + x_size=self.decode_chunk_size // ds, + ) + + if zipformer_downsampling_factors[i] != 1: + in_x_size = self.decode_chunk_size + encoder = DownsampledZipformerEncoder( + encoder, + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + output_dim=encoder_dims[i], + downsample=zipformer_downsampling_factors[i], + is_pnnx=is_pnnx, + left_context_len=self.left_context_len // ds, + in_x_size=in_x_size, + ) + encoders.append(encoder) + self.encoders = nn.ModuleList(encoders) + + # initializes self.skip_layers and self.skip_modules + self._init_skip_modules() + + self.downsample_output = AttentionDownsample( + encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor, + is_pnnx=is_pnnx, + in_x_size=self.decode_chunk_size, + ) + + def _get_layer_skip_dropout_prob(self): + if not self.training: + return 0.0 + batch_count = self.batch_count + min_dropout_prob = 0.025 + + if batch_count > self.warmup_end: + return min_dropout_prob + else: + return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) + + def _init_skip_modules(self): + """ + If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of + layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, + we combine the outputs of layers 1 and 4. + """ + skip_layers = [] + skip_modules = [] + z = self.zipformer_downsampling_factors + for i in range(len(z)): + if i <= 1 or z[i - 1] <= z[i]: + skip_layers.append(None) + skip_modules.append(SimpleCombinerIdentity()) + else: + # TEMP + for j in range(i - 2, -1, -1): + if z[j] <= z[i] or j == 0: + # TEMP logging statement. + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." 
+ ) + skip_layers.append(j) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) + break + self.skip_layers = skip_layers + self.skip_modules = nn.ModuleList(skip_modules) + + def get_feature_masks(self, x: torch.Tensor) -> List[float]: + # Note: The actual return type is Union[List[float], List[Tensor]], + # but to make torch.jit.script() work, we use List[float] + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all encoder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoder dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_downsampling_factors times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (num_frames, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dims) + if torch.jit.is_scripting() or not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dims[0] == _encoder_dims0, ( + self.encoder_dims, + _encoder_dims0, + ) + + max_downsampling_factor = max(self.zipformer_downsampling_factors) + + num_frames_max = num_frames0 + max_downsampling_factor - 1 + + feature_mask_dropout_prob = 0.15 + + # frame_mask_max shape: (num_frames_max, batch_size, 1) + frame_mask_max = ( + torch.rand(num_frames_max, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) + + feature_masks = [] + for i in range(num_encoders): + ds = self.zipformer_downsampling_factors[i] + upsample_factor = max_downsampling_factor // ds + + frame_mask = ( + frame_mask_max.unsqueeze(1) + .expand(num_frames_max, upsample_factor, batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1) + ) + num_frames = (num_frames0 + ds - 1) // ds + frame_mask = frame_mask[:num_frames] + feature_mask = torch.ones( + num_frames, + batch_size, + self.encoder_dims[i], + dtype=x.dtype, + device=x.device, + ) + u = self.encoder_unmasked_dims[i] + feature_mask[:, :, u:] *= frame_mask + feature_masks.append(feature_mask) + + return feature_masks + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + chunk_size: + The chunk size used in evaluation mode. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. 
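+            With the Conv2dSubsampling frontend and the default
+            output_downsampling_factor of 2, output_seq_len equals
+            ((seq_len - 7) // 2 + 1) // 2.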
+ """ + x = self.encoder_embed(x) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + mask = make_pad_mask(lengths) + + outputs = [] + feature_masks = self.get_feature_masks(x) + + if self.training: + # Training mode + max_ds = max(self.zipformer_downsampling_factors) + # Generate dynamic chunk-wise attention mask during training + max_len = x.size(0) // max_ds + short_chunk_size = self.short_chunk_size // max_ds + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + # Full attention + chunk_size = x.size(0) + else: + # Chunk-wise attention + chunk_size = chunk_size % short_chunk_size + 1 + chunk_size *= max_ds + else: + chunk_size = self.decode_chunk_size + # Evaluation mode + for ds in self.zipformer_downsampling_factors: + assert chunk_size % ds == 0, (chunk_size, ds) + + attn_mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, + ) + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + ds = self.zipformer_downsampling_factors[i] + k = self.skip_layers[i] + if isinstance(k, int): + layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() + if torch.jit.is_scripting(): + x = skip_module(outputs[k], x) + elif (not self.training) or random.random() > layer_skip_dropout_prob: + x = skip_module(outputs[k], x) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + attn_mask=attn_mask[::ds, ::ds], + ) + outputs.append(x) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + def streaming_forward( + self, + x: torch.Tensor, + states: List[Tensor], + ) -> Tuple[Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + seq_len is the input chunk length. + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 3 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states. 
+ """ + assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) + + cached_len = states[: self.num_encoders] + cached_avg = states[self.num_encoders : 2 * self.num_encoders] + cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] + cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] + cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] + cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] + cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] + + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + outputs = [] + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + k = self.skip_layers[i] + if isinstance(k, int): + x = skip_module(outputs[k], x) + x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( + x, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + + outputs.append(x) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = ( + new_cached_len + + new_cached_avg + + new_cached_key + + new_cached_val + + new_cached_val2 + + new_cached_conv1 + + new_cached_conv2 + ) + return x, new_states + + @torch.jit.export + def get_init_state( + self, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. 
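+        All returned tensors are zero-initialized and have batch size 1; states
+        for a batch of streams can be combined with stack_states() and split
+        back again with unstack_states().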
+ """ + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + for i, encoder in enumerate(self.encoders): + num_layers = encoder.num_layers + ds = self.zipformer_downsampling_factors[i] + + len_avg = torch.zeros(num_layers, 1, device=device) + cached_len.append(len_avg) + + avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) + cached_avg.append(avg) + + key = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim, + device=device, + ) + cached_key.append(key) + + val = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val.append(val) + + val2 = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val2.append(val2) + + conv1 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv1.append(conv1) + + conv2 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +class ZipformerEncoderLayer(nn.Module): + """ + ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, + is_pnnx: bool = False, + left_context_len: int = 0, + x_size: int = 0, + ) -> None: + super(ZipformerEncoderLayer, self).__init__() + + self.d_model = d_model + self.attention_dim = attention_dim + self.cnn_module_kernel = cnn_module_kernel + + # will be written to, see set_batch_count() + self.batch_count = 0 + + self.self_attn = RelPositionMultiheadAttention( + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, + is_pnnx=is_pnnx, + left_context_len=left_context_len, + x_size=x_size, + ) + + self.pooling = PoolingModule(d_model) + + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.conv_module1 = ConvolutionModule( + d_model, cnn_module_kernel, is_pnnx=is_pnnx, x_size=x_size + ) + + self.conv_module2 = ConvolutionModule( + d_model, cnn_module_kernel, is_pnnx=is_pnnx, x_size=x_size + ) + + self.norm_final = BasicNorm(d_model) + + self.bypass_scale = nn.Parameter(torch.tensor(0.5)) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
+ self.balancer = ActivationBalancer( + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + max_abs=6.0, + ) + self.whiten = Whiten( + num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 + ) + + def get_bypass_scale(self): + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + if random.random() < 0.1: + # ensure we get grads if self.bypass_scale becomes out of range + return self.bypass_scale + # hardcode warmup period for bypass scale + warmup_period = 20000.0 + initial_clamp_min = 0.75 + final_clamp_min = 0.25 + if self.batch_count > warmup_period: + clamp_min = final_clamp_min + else: + clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( + initial_clamp_min - final_clamp_min + ) + return self.bypass_scale.clamp(min=clamp_min, max=1.0) + + def get_dynamic_dropout_rate(self): + # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this + # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable + # at the beginning, by making the network focus on the feedforward modules. + if torch.jit.is_scripting() or not self.training: + return 0.0 + warmup_period = 2000.0 + initial_dropout_rate = 0.2 + final_dropout_rate = 0.0 + if self.batch_count > warmup_period: + return final_dropout_rate + else: + return initial_dropout_rate - ( + initial_dropout_rate * final_dropout_rate + ) * (self.batch_count / warmup_period) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + batch_split: if not None, this layer will only be applied to + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + # dropout rate for submodules that interact with time. 
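+        # (With probability dynamic_dropout, each of the pooling, self-attention
+        # and convolution branches below is skipped for the whole batch during
+        # training; see get_dynamic_dropout_rate() for the warmup schedule.)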
+ dynamic_dropout = self.get_dynamic_dropout_rate() + + # pooling module + if torch.jit.is_scripting(): + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + elif random.random() >= dynamic_dropout: + src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) + + if torch.jit.is_scripting(): + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + src = src + self.self_attn.forward2(src, attn_weights) + + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + else: + use_self_attn = random.random() >= dynamic_dropout + if use_self_attn: + src_att, attn_weights = self.self_attn( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + src = src + src_att + + if random.random() >= dynamic_dropout: + src = src + self.conv_module1( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward2(src) + + if use_self_attn: + src = src + self.self_attn.forward2(src, attn_weights) + + if random.random() >= dynamic_dropout: + src = src + self.conv_module2( + src, src_key_padding_mask=src_key_padding_mask + ) + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.get_bypass_scale() + + return self.whiten(src) + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + cached_len: processed number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor of left context for the first attention module. + cached_val: cached value tensor of left context for the first attention module. + cached_val2: cached value tensor of left context for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + pos_emb: (N, left_context_len+2*S-1, E) + cached_len: (N,) + N is the batch size. + cached_avg: (N, C). + N is the batch size, C is the feature dimension. + cached_key: (left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (N, C, kernel_size-1). + N is the batch size, C is the convolution channels. 
+ """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + src_pool, cached_len, cached_avg = self.pooling.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + ) + src = src + src_pool + + ( + src_attn, + attn_weights, + cached_key, + cached_val, + ) = self.self_attn.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + cached_val=cached_val, + ) + + src = src + src_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + ) + + src = src + src_conv + + src = src + self.feed_forward2(src) + + src_attn, cached_val2 = self.self_attn.streaming_forward2( + src, + attn_weights, + cached_val=cached_val2, + ) + src = src + src_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.bypass_scale + + return ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class ZipformerStateSelect(nn.Module): + """ncnn does not support selecting along batch index. + This class provides a workaround for it. We + need to change pnnx accordingly. + """ + + def __init__(self, i: int): + super().__init__() + self.i = i + + def forward(self, x: torch.Tensor): + return x[self.i] + + +class ZipformerEncoder(nn.Module): + r"""ZipformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZipformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + is_pnnx: bool = False, + x_size: int = 0, + left_context_len: int = 0, + ) -> None: + super().__init__() + # will be written to, see set_batch_count() Note: in inference time this + # may be zero but should be treated as large, we can check if + # self.training is true. + self.batch_count = 0 + self.warmup_begin = warmup_begin + self.warmup_end = warmup_end + # module_seed is for when we need a random number that is unique to the module but + # shared across jobs. It's used to randomly select how many layers to drop, + # so that we can keep this consistent across worker tasks (for efficiency). 
+ self.module_seed = torch.randint(0, 1000, ()).item() + self.left_context_len = left_context_len + + self.encoder_pos = RelPositionalEncoding( + encoder_layer.d_model, + dropout, + is_pnnx=is_pnnx, + x_size=x_size, + left_context_len=left_context_len, + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + state_select_list = [] + for i in range(num_layers): + state_select_list.append(ZipformerStateSelect(i)) + self.state_select_list = nn.ModuleList(state_select_list) + + self.d_model = encoder_layer.d_model + self.attention_dim = encoder_layer.attention_dim + self.cnn_module_kernel = encoder_layer.cnn_module_kernel + + assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin + for i in range(num_layers): + self.layers[i].warmup_begin = cur_begin + cur_begin += delta + self.layers[i].warmup_end = cur_begin + + def get_layers_to_drop(self, rnd_seed: int): + ans = set() + if not self.training: + return ans + + batch_count = self.batch_count + num_layers = len(self.layers) + + def get_layerdrop_prob(layer: int) -> float: + layer_warmup_begin = self.layers[layer].warmup_begin + layer_warmup_end = self.layers[layer].warmup_end + + initial_layerdrop_prob = 0.5 + final_layerdrop_prob = 0.05 + + if batch_count == 0: + # As a special case, if batch_count == 0, return 0 (drop no + # layers). This is rather ugly, I'm afraid; it is intended to + # enable our scan_pessimistic_batches_for_oom() code to work correctly + # so if we are going to get OOM it will happen early. + # also search for 'batch_count' with quotes in this file to see + # how we initialize the warmup count to a random number between + # 0 and 10. + return 0.0 + elif batch_count < layer_warmup_begin: + return initial_layerdrop_prob + elif batch_count > layer_warmup_end: + return final_layerdrop_prob + else: + # linearly interpolate + t = (batch_count - layer_warmup_begin) / layer_warmup_end + assert 0.0 <= t < 1.001, t + return initial_layerdrop_prob + t * ( + final_layerdrop_prob - initial_layerdrop_prob + ) + + shared_rng = random.Random(batch_count + self.module_seed) + independent_rng = random.Random(rnd_seed) + + layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + tot = sum(layerdrop_probs) + # Instead of drawing the samples independently, we first randomly decide + # how many layers to drop out, using the same random number generator between + # jobs so that all jobs drop out the same number (this is for speed). + # Then we use an approximate approach to drop out the individual layers + # with their specified probs while reaching this exact target. + num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) + + layers = list(range(num_layers)) + independent_rng.shuffle(layers) + + # go through the shuffled layers until we get the required number of samples. 
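+        # (For example, if the per-layer probabilities sum to tot = 1.7, then
+        # num_to_drop above is 2 with probability 0.7 and 1 otherwise, so its
+        # expectation equals tot; drawing it from shared_rng keeps the count
+        # identical across workers for a given batch_count.)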
+ if num_to_drop > 0: + for layer in itertools.cycle(layers): + if independent_rng.random() < layerdrop_probs[layer]: + ans.add(layer) + if len(ans) == num_to_drop: + break + if shared_rng.random() < 0.005 or __name__ == "__main__": + logging.info( + f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" + ) + return ans + + def forward( + self, + src: Tensor, + # Note: The type of feature_mask should be Union[float, Tensor], + # but to make torch.jit.script() work, we use `float` here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: (x, x_no_combine), both of shape (S, N, E) + """ + pos_emb = self.encoder_pos(src) + output = src + + if torch.jit.is_scripting(): + layers_to_drop = [] + else: + rnd_seed = src.numel() + random.randint(0, 1000) + layers_to_drop = self.get_layers_to_drop(rnd_seed) + + output = output * feature_mask + + for i, mod in enumerate(self.layers): + if not torch.jit.is_scripting(): + if i in layers_to_drop: + continue + output = mod( + output, + pos_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + output = output * feature_mask + + return output + + @torch.jit.export + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + cached_len: number of past frames. + cached_avg: cached average of past frames. + cached_key: cached key tensor for first attention module. + cached_val: cached value tensor for first attention module. + cached_val2: cached value tensor for second attention module. + cached_conv1: cached left contexts for the first convolution module. + cached_conv2: cached left contexts for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (num_layers,) + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + + Returns: A tuple of 8 tensors: + - output tensor + - updated cached number of past frames. + - updated cached average of past frames. + - updated cached key tensor of of the first attention module. 
+ - updated cached value tensor of of the first attention module. + - updated cached value tensor of of the second attention module. + - updated cached left contexts of the first convolution module. + - updated cached left contexts of the second convolution module. + """ + assert cached_len.size(0) == self.num_layers, ( + cached_len.size(0), + self.num_layers, + ) + assert cached_avg.size(0) == self.num_layers, ( + cached_avg.size(0), + self.num_layers, + ) + assert cached_key.size(0) == self.num_layers, ( + cached_key.size(0), + self.num_layers, + ) + assert cached_val.size(0) == self.num_layers, ( + cached_val.size(0), + self.num_layers, + ) + assert cached_val2.size(0) == self.num_layers, ( + cached_val2.size(0), + self.num_layers, + ) + assert cached_conv1.size(0) == self.num_layers, ( + cached_conv1.size(0), + self.num_layers, + ) + assert cached_conv2.size(0) == self.num_layers, ( + cached_conv2.size(0), + self.num_layers, + ) + + assert self.left_context_len == cached_key.shape[1], ( + self.left_context_len, + cached_key.shape[1], + ) + + left_context_len = self.left_context_len + pos_emb = self.encoder_pos(src, left_context_len) + + output = src + + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + for i, (mod, state_select) in enumerate( + zip(self.layers, self.state_select_list) + ): + output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( + output, + pos_emb, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=state_select(cached_conv1), + cached_conv2=state_select(cached_conv2), + ) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + return ( + output, + torch.stack(new_cached_len, dim=0), + torch.stack(new_cached_avg, dim=0), + torch.stack(new_cached_key, dim=0), + torch.stack(new_cached_val, dim=0), + torch.stack(new_cached_val2, dim=0), + torch.stack(new_cached_conv1, dim=0), + torch.stack(new_cached_conv2, dim=0), + ) + + +class DownsampledZipformerEncoder(nn.Module): + r""" + DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. 
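+    For example, with downsample=4 an input of S frames is processed by the
+    inner encoder at ceil(S/4) frames and then upsampled back to S frames
+    before being combined with the original input.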
+ """ + + def __init__( + self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int, + is_pnnx: bool = False, + left_context_len: int = 0, + in_x_size: int = 0, + ): + super(DownsampledZipformerEncoder, self).__init__() + self.downsample_factor = downsample + self.downsample = AttentionDownsample( + input_dim, output_dim, downsample, is_pnnx=is_pnnx, in_x_size=in_x_size + ) + self.encoder = encoder + self.num_layers = encoder.num_layers + self.d_model = encoder.d_model + self.attention_dim = encoder.attention_dim + self.cnn_module_kernel = encoder.cnn_module_kernel + self.upsample = SimpleUpsample(output_dim, downsample) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) + self.in_x_size = in_x_size + + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. feature_mask is expected to be already downsampled by + self.downsample_factor. + attn_mask: attention mask (optional). Should be downsampled already. + src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. + + Shape: + src: (S, N, E). + attn_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + src = self.encoder( + src, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + cached_avg: cached average value of past frames. + cached_len: length of past frames. + cached_key: cached key tensor for the first attention module. + cached_val: cached value tensor for the first attention module. + cached_val2: cached value tensor for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (N,) + N is the batch size. + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). 
+ N is the batch size, C is the convolution channels. + cached_conv2: (num_layers, N, C, kernel_size-1). + N is the batch size, C is the convolution channels. + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size) + + src_orig = src + + src = self.downsample(src) + + ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) = self.encoder.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + cached_key=cached_key, + cached_val=cached_val, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + ) + + src = self.upsample(src) + + if src.shape[0] != self.in_x_size: + # remove any extra frames that are not a multiple of downsample_factor + src = src[: self.in_x_size] + + return ( + self.out_combiner(src_orig, src), + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class AttentionDownsampleUnsqueeze(torch.nn.Module): + """We apply this operation only in PyTorch + and discards in ncnn. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.unsqueeze(1) + + +class AttentionDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + downsample: int, + is_pnnx: bool = False, + in_x_size: int = 0, + ): + super(AttentionDownsample, self).__init__() + + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + + self.in_channels = in_channels + self.out_channels = out_channels + self.is_pnnx = is_pnnx + self.in_x_size = in_x_size + + self.unsqueeze = AttentionDownsampleUnsqueeze() + + # fill in the extra dimensions with a projection of the input + if out_channels > in_channels: + self.extra_proj = nn.Linear( + in_channels * downsample, out_channels - in_channels, bias=False + ) + else: + self.extra_proj = None + self.downsample = downsample + + self.d_seq_len = (in_x_size + downsample - 1) // downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, 1, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, out_channels) + """ + assert src.shape[0] == self.in_x_size, ( + src.shape[0], + self.in_x_size, + src.shape, + type(src), + ) + assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels) + if not self.is_pnnx: + (seq_len, batch_size, in_channels) = src.shape + else: + seq_len = self.in_x_size + batch_size = 1 + in_channels = self.in_channels + + ds = self.downsample + d_seq_len = self.d_seq_len + + # Pad to an exact multiple of self.downsample + if seq_len != d_seq_len * ds: + assert self.is_pnnx is False, "TODO(fangjun): Handle it!" + # right-pad src, repeating the last element. 
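+            # Worked example (illustrative numbers only): with seq_len=10 and
+            # ds=4 we get d_seq_len=(10+4-1)//4=3, so pad=3*4-10=2 copies of
+            # the last frame are appended, making the length an exact multiple
+            # of ds.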
+ pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + + if not self.is_pnnx: + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + src = src.permute(0, 2, 1, 3).reshape( + d_seq_len, batch_size, ds * in_channels + ) + + if self.extra_proj is not None: + ans2 = self.extra_proj(src) + ans = torch.cat((ans, ans2), dim=2) + else: + src = src.reshape(d_seq_len, ds, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + assert ( + self.extra_proj is None + ), "The code for it being not None is not tested" + # ans = ans.unsqueeze(1) + ans = self.unsqueeze(ans) + # Note: In ncnn, we ignore self.unsqueeze + # so ans in ncnn is still a 2-D tensor, e.g., (8, 384) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + self.upsample = upsample + self.num_channels = num_channels + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.bias.shape[0] + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src + self.bias.unsqueeze(1) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class SimpleCombinerIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + return src1 + + +class SimpleCombiner(torch.nn.Module): + """ + A very simple way of combining 2 vectors of 2 different dims, via a + learned weighted combination in the shared part of the dim. + Args: + dim1: the dimension of the first input, e.g. 256 + dim2: the dimension of the second input, e.g. 384. + The output will have the same dimension as dim2. 
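+    For example, combining src1 of shape (*, 256) with src2 of shape (*, 384)
+    zero-pads the weighted src1 to 384 channels before the two are added.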
+ """ + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + super(SimpleCombiner, self).__init__() + assert dim2 >= dim1, (dim2, dim1) + self.weight1 = nn.Parameter(torch.zeros(())) + self.min_weight = min_weight + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + """ + src1: (*, dim1) + src2: (*, dim2) + + Returns: a tensor of shape (*, dim2) + """ + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) + + weight1 = self.weight1 + if not torch.jit.is_scripting(): + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) + + src1 = src1 * weight1 + src2 = src2 * (1.0 - weight1) + + assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1) + assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2) + + src1_dim = self.dim1 + src2_dim = self.dim2 + + if src1_dim != src2_dim: + if src1_dim < src2_dim: + src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) + else: + src1 = src1[:src2_dim] + + return src1 + src2 + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + is_pnnx: bool = False, + x_size: int = 0, + left_context_len: int = 0, + ) -> None: + """Construct a PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(dropout_rate) + self.is_pnnx = is_pnnx + self.x_size = x_size + self.left_context_len = left_context_len + self.pe = None + if is_pnnx: + x_size_left = x_size + left_context_len + self.extend_pe(torch.tensor(0.0).expand(x_size_left)) + self.pe = self.pe[:, :-left_context_len] + assert self.pe.size(1) == x_size + left_context_len - 1 + x_size, ( + self.pe.size(1), + x_size, + left_context_len, + x_size, + self.pe.shape, + ) + else: + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + x_size_left = x.size(0) + left_context_len + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_left * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tensor: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). 
+ + """ + if self.is_pnnx: + assert self.x_size == x.size(0), (self.x_size, x.size(0)) + assert self.left_context_len == left_context_len, ( + self.left_context_len, + left_context_len, + ) + return self.pe + + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_left + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(0), + ] + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionPermute(nn.Module): + """ncnn does not support permuatation relating to the batch axis 0. + This is a workaround for exporting to ncnn via PNNX. + """ + + def __init__(self, kind: int): + super().__init__() + self.kind = kind + assert self.kind in (2, 3), self.kind + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.kind == 2: + return x.permute(1, 0, 2) + elif self.kind == 3: + return x.permute(1, 2, 0) + else: + assert False, f"Unsupported kind {self.kind}" + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, may be less or more than embed_dim + but must be a multiple of num_heads. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + pos_dim: int, + dropout: float = 0.0, + is_pnnx: bool = False, + left_context_len: int = 0, + x_size: int = 0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = attention_dim // num_heads + self.pos_dim = pos_dim + assert self.head_dim % 2 == 0, self.head_dim + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) + + self.is_pnnx = is_pnnx + + self.my_permute_pqv = RelPositionMultiheadAttentionPermute(kind=2) + self.my_permute_k_pos = RelPositionMultiheadAttentionPermute(kind=3) + self.left_context_len = left_context_len + self.x_size = x_size + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5, dividing it between the query and key. + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query (attention_dim,), key (attention_dim,) + + pos_dim * num_heads # value (attention_dim // 2,) + ) # positional encoding query (pos_dim * num_heads, ) + + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) + + # self.whiten_values is applied on the values in forward(); + # it just copies the keys but prevents low-rank distribution by modifying grads. + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # linear transformation for positional encoding. 
+ self.linear_pos = ScaledLinear( + embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option. + # they only copy their inputs. + self.copy_pos_query = Identity() + self.copy_query = Identity() + + self.out_proj = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + + self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) + self.out_proj2 = ScaledLinear( + attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 + ) + # self.whiten_values2 is applied on the values in forward2() + self.whiten_values2 = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Returns: (attn_output, attn_weights) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. 
+ """ + x, weights = self.multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + return x, weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. + - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. + + - Returns: (attn_output, attn_weights, cached_key, cached_val) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of + left context + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of + """ + ( + x, + weights, + cached_key, + cached_val, + ) = self.streaming_multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.out_proj.weight, + self.out_proj.bias, + cached_key=cached_key, + cached_val=cached_val, + ) + return x, weights, cached_key, cached_val + + def multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[..., 2 * attention_dim + value_dim :] + + k = self.whiten_keys(k) # does nothing in the forward pass. + v = self.whiten_values(v) # does nothing in the forward pass. + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, seq_len, seq_len]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + seq_len, + seq_len, + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+ ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(seq_len, bsz, num_heads, head_dim) + v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == seq_len, "{} == {}".format( + key_padding_mask.size(1), seq_len + ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, seq_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + if not torch.jit.is_scripting(): + if training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be too large. + # It incurs a penalty if any of them has an absolute value greater than 50.0. + # this should be outside the normal range of the attention scores. We use + # this mechanism instead of, say, a limit on entropy, because once the entropy + # gets very small gradients through the softmax can become very small, and + # some mechanisms like that become ineffective. + attn_output_weights = penalize_abs_values_gt( + attn_output_weights, limit=25.0, penalty=1.0e-04 + ) + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask, float("-inf") + ) + else: + attn_output_weights = attn_output_weights + attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + # If we are using chunk-wise attention mask and setting a limited + # num_left_chunks, the attention may only see the padding values which + # will also be masked out by `key_padding_mask`. 
At this circumstances, + # the whole column of `attn_output_weights` will be `-inf` + # (i.e. be `nan` after softmax). So we fill `0.0` at the masking + # positions to avoid invalid loss value below. + if ( + attn_mask is not None + and attn_mask.dtype == torch.bool + and key_padding_mask is not None + ): + if attn_mask.size(0) != 1: + attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) + combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + 1 + ).unsqueeze(2) + + attn_output_weights = attn_output_weights.view( + bsz, num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + return attn_output, attn_output_weights + + def streaming_multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + out_proj_weight, out_proj_bias: the output projection weight and bias. + cached_key: cached attention key tensor of left context. + cached_val: cached attention value tensor of left context. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. 
+ """ + if not self.is_pnnx: + seq_len, bsz, _ = x_proj.size() + assert seq_len == self.x_size, (seq_len, self.x_size) + else: + seq_len = self.x_size + bsz = 1 + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[:, :, 0:attention_dim] # (x_size, N, attention_dim) + # return q, q, q, q + k = x_proj[:, :, attention_dim : 2 * attention_dim] + # k is (x_size, N, attention_dim) + value_dim = attention_dim // 2 + v = x_proj[:, :, 2 * attention_dim : 2 * attention_dim + value_dim] + # v is (x_size, 0, attention_dim//2) + + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[:, :, 2 * attention_dim + value_dim :] + # p is (x_size, N, pos_dim * num_heads) + + if not self.is_pnnx: + left_context_len = cached_key.shape[0] + else: + assert cached_key.shape[0] == self.left_context_len, ( + cached_key.shape, + self.left_context_len, + ) + left_context_len = self.left_context_len + + assert left_context_len > 0, left_context_len + assert cached_key.shape[0] == cached_val.shape[0], ( + cached_key.shape, + cached_val.shape, + ) + # Note: We need to fix the Concat in ncnn + # cached_key is (1, 64, 192) in ncnn + # k is (16, 192) in ncnn + # Pad cached left contexts + k = torch.cat([cached_key, k], dim=0) + # (left_context_len + x_size, N, attention_dim) + + v = torch.cat([cached_val, v], dim=0) + # v: (left_context_len + x_size, N, attention_dim//2) + # Update cached left contexts + if not self.is_pnnx: + cached_key = k[-left_context_len:, ...] + cached_val = v[-left_context_len:, ...] + else: + cached_key = k[self.x_size :] + cached_val = v[self.x_size :] + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape, + left_context_len, + ) + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape, + left_context_len, + ) + + if not self.is_pnnx: + # The length of key and value + kv_len = k.shape[0] + else: + kv_len = left_context_len + self.x_size + assert kv_len == k.shape[0], (kv_len, k.shape) + + if not self.is_pnnx: + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(kv_len, bsz, num_heads, head_dim) + + v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + # v is (bsz * num_heads, kv_len, head_dim//2) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + left_context_len + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + else: + q = q.reshape(seq_len, num_heads, head_dim) + p = p.reshape(seq_len, num_heads, pos_dim) + k = k.reshape(kv_len, num_heads, head_dim) + # v = v.reshape(kv_len, num_heads, head_dim // 2).permute(1, 0, 2) + v = v.reshape(kv_len, num_heads, head_dim // 2) + v = self.my_permute_pqv(v) + # v is (num_heads, kv_len, head_dim//2) e.g., (8, 80, 12) + + # q = q.permute(1, 0, 2) # (head, time1, head_dim) + # p = p.permute(1, 0, 2) # (head, time1, pos_dim) + # k = k.permute(1, 2, 0) # (head, d_k, time2) + + q = self.my_permute_pqv(q) # (head, time1, head_dim), e.g., (8, 16, 24) + p = self.my_permute_pqv(p) # (head, time1, pos_dim), e.g., (8, 16, 4) + k = self.my_permute_k_pos(k) # (head, d_k, 
time2) e.g., (8, 24, 80) + + seq_len2 = 2 * seq_len - 1 + left_context_len + # pos = pos.reshape(seq_len2, num_heads, pos_dim).permute(1, 2, 0) + # pos shape now: (head, pos_dim, seq_len2) + + pos = pos.reshape(seq_len2, num_heads, pos_dim) + pos = self.my_permute_k_pos( + pos + ) # (head, pos_dim, seq_len2), e.g, (8, 4, 95) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) ,e.g., (1, 8, 16, 95) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + + if not self.is_pnnx: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + else: + pos_weights = pos_weights.as_strided( + (num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1) - pos_weights.stride(2), + pos_weights.stride(2), + ), + storage_offset=pos_weights.stride(2) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + # (8, 16, 12) + + if not self.is_pnnx: + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + else: + attn_output = self.my_permute_pqv(attn_output) # (1, 0, 2) + attn_output = attn_output.reshape(seq_len, bsz, attention_dim // 2) + # We have changed InnerProduct in ncnn to treat + # (seq_len, bsz, attention_dim//2) as + # (seq_len, attention_dim//2) + + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + return attn_output, attn_output_weights, cached_key, cached_val + + def forward2( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + Returns: + output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + v = self.whiten_values2(v) # does nothing in the forward pass. 
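+        # Shape walk-through (illustrative sizes only): with seq_len=16, bsz=1,
+        # num_heads=8 and head_dim=24, v becomes (8, 16, 12) after the
+        # reshape/transpose below, and bmm with attn_weights of shape
+        # (8, 16, 16) yields an attn_output of (8, 16, 12).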
+ v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not torch.jit.is_scripting(): + if random.random() < 0.001 or __name__ == "__main__": + self._print_attn_stats(attn_weights, attn_output) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output) + + def streaming_forward2( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + cached_val: cached attention value tensor of left context. + Returns: + - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + - updated cached attention value tensor of left context. + """ + num_heads = self.num_heads + + assert x.shape[0] == self.x_size, (x.shape[0], self.x_size) + assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim) + + if not self.is_pnnx: + (seq_len, bsz, embed_dim) = x.shape + else: + seq_len = self.x_size + bsz = 1 + embed_dim = self.embed_dim + + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + + assert cached_val.shape[0] == self.left_context_len, ( + cached_val.shape[0], + self.left_context_len, + ) + + left_context_len = self.left_context_len + assert left_context_len > 0, left_context_len + v = torch.cat([cached_val, v], dim=0) + cached_val = v[-left_context_len:] + + seq_len2 = left_context_len + seq_len + if not self.is_pnnx: + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) + else: + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2) + # v = v.permute(1, 0, 2) + v = self.my_permute_pqv(v) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not self.is_pnnx: + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + else: + attn_output = self.my_permute_pqv(attn_output) # (1, 0, 2) + attn_output = attn_output.reshape(seq_len, bsz, self.attention_dim // 2) + # We have changed InnerProduct in ncnn to ignore bsz + # when invoking self.out_proj2(attn_output) + + # returned value is of shape (seq_len, bsz, embed_dim), like x. 
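+        # cached_val now holds the last `left_context_len` value frames and is
+        # returned so it can be fed back in as the cache for the next chunk.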
+ return self.out_proj2(attn_output), cached_val + + def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + # attn_weights: (batch_size * num_heads, seq_len, seq_len) + # attn_output: (bsz * num_heads, seq_len, head_dim) + (n, seq_len, head_dim) = attn_output.shape + num_heads = self.num_heads + bsz = n // num_heads + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_output = attn_output.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .reshape(bsz, num_heads, seq_len) + .mean(dim=(0, 2)) + ) + attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) + attn_output = attn_output.permute(1, 0, 2, 3).reshape( + num_heads, bsz * seq_len, head_dim + ) + attn_output_mean = attn_output.mean(dim=1, keepdim=True) + attn_output = attn_output - attn_output_mean + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( + bsz * seq_len + ) + # attn_covar: (num_heads, head_dim, head_dim) + # eigs, _ = torch.symeig(attn_covar) + # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + + attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) + embed_dim = self.in_proj2.weight.shape[1] + in_proj_covar = ( + self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 + ).mean(dim=(1, 2)) + out_proj_covar = ( + self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 + ).mean(dim=(0, 2)) + logging.info( + f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" + ) + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer model.""" + + def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(d_model, feedforward_dim) + self.balancer = ActivationBalancer( + feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 + ) + self.activation = DoubleSwish() + self.dropout = nn.Dropout(dropout) + self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.balancer(x) + x = self.activation(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + is_pnnx: bool = False, + x_size: int = 0, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0, kernel_size + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. 
This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer( + 2 * channels, + channel_dim=1, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + # Will pad cached left context + self.lorder = kernel_size - 1 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=0, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + max_abs=20.0, + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + self.is_pnnx = is_pnnx + self.x_size = x_size + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains bool in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + # 1D Depthwise Conv + # Make depthwise_conv causal by + # manualy padding self.lorder zeros to the left + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch: + (batch, #time), contains bool in masked positions. + cache: Cached left context for depthwise_conv, with shape of + (batch, channels, #kernel_size-1). Only used in real streaming decoding. + + Returns: + A tuple of 2 tensors: + - Output tensor (#time, batch, channels). + - New cached left context, with shape of (batch, channels, #kernel_size-1). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). 
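+        # In streaming mode the zero left-padding used in forward() is replaced
+        # below by `cache`, i.e. the last kernel_size-1 frames of the previous
+        # chunk, so the depthwise convolution stays causal across chunk
+        # boundaries.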
+ + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( + cache.shape, + (x.size(0), x.size(1), self.lorder), + ) + x = torch.cat([cache, x], dim=2) + + cache = x[:, :, self.x_size :] + + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1), cache + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: float = 0.1, + is_pnnx: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-7)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer2_channels: + Number of channels in layer2 + layer3_channels: + Number of channels in layer3 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 7, in_channels + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + # After this layer (N, 1, T, C) -> (N, layer1_channels, T-2, C) + ActivationBalancer(layer1_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + # After this layer (N, layer1_channels, T-2, C) -> (N, layer2_channels, ((T-2) - 3)//2+1, (C-3)//2+1) + # i.e., (N, layer2_channels, (T-5)//2+1, (C-3)//2+1) + # i.e., (N, layer2_channels, (T-3)//2, (C-1)//2) + ActivationBalancer(layer2_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + # After this layer, (N, layer2_channels, (T-3)//2, (C-1)//2) + # -> + # (N, layer3_channels, (T-3)//2-2, ((C-1)//2 - 3)//2 + 1) + # (N, layer3_channels, (T-7)//2, (C-3)//4) + ActivationBalancer(layer3_channels, channel_dim=1), + DoubleSwish(), + ) + out_height = (((in_channels - 1) // 2) - 1) // 2 + self.out = ScaledLinear(out_height * layer3_channels, out_channels) + self.dropout = nn.Dropout(dropout) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). 
+ + Returns: + Return a tensor of shape (N, (T-7)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) + # Now x is of shape (N, (T-7)//2, odim) + x = self.dropout(x) + return x + + +def _test_zipformer_main(): + feature_dim = 50 + batch_size = 5 + seq_len = 47 + feature_dim = 50 + # Just make sure the forward pass runs. + + c = Zipformer( + num_features=feature_dim, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + decode_chunk_size=4, + ) + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +def _test_conv2d_subsampling(): + num_features = 80 + encoder_dims = 384 + dropout = 0.1 + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + for i in range(20, 40): + x = torch.rand(2, i, num_features) + y = encoder_embed(x) + assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + + +def _test_pooling_module(): + N, S, C = 2, 12, 32 + chunk_len = 4 + m = PoolingModule(d_model=C) + + # test chunk-wise forward with padding_mask + x = torch.randn(S, N, C) + y = m(x) + cached_len = torch.zeros(N, dtype=torch.int32) + cached_avg = torch.zeros(N, C) + for i in range(S // chunk_len): + start = i * chunk_len + end = start + chunk_len + x_chunk = x[start:end] + y_chunk, cached_len, cached_avg = m.streaming_forward( + x_chunk, + cached_len=cached_len, + cached_avg=cached_avg, + ) + assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) + + +def _test_state_stack_unstack(): + m = Zipformer( + num_features=80, + encoder_dims=(64, 96), + encoder_unmasked_dims=(48, 64), + nhead=(4, 4), + zipformer_downsampling_factors=(4, 8), + num_left_chunks=2, + decode_chunk_size=8, + ) + s1 = m.get_init_state() + s2 = m.get_init_state() + states = stack_states([s1, s2]) + new_s1, new_s2 = unstack_states(states) + for i in range(m.num_encoders * 7): + for x, y in zip(s1[i], new_s1[i]): + assert torch.equal(x, y) + for x, y in zip(s2[i], new_s2[i]): + assert torch.equal(x, y) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main() + _test_conv2d_subsampling() + _test_pooling_module() + _test_state_stack_unstack() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/asr_datamodule.py new file mode 120000 index 000000000..a3a1584d1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless8/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/beam_search.py new 
file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py new file mode 100755 index 000000000..35158ced4 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py @@ -0,0 +1,818 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math 
+from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from librispeech import LibriSpeech +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--right-padding", + type=int, + default=64, + help="Padding frames at the end of features", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
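# Illustrative note: the returned dict maps one decoding key to a word-level
# hypothesis per utterance, e.g. (hypothetical transcripts)
#     {"greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}
# For the fast_beam_search variants the key encodes the search settings, e.g.
# "beam_20.0_max_contexts_8_max_states_64" with the default argument values.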
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += params.right_padding + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.right_padding), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += 
f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
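# Illustrative sketch: `write_error_stats` (from icefall.utils) computes and
# logs the detailed per-word statistics and returns the WER for this decoding
# key.  The WER itself is word-level edit distance divided by the number of
# reference words; a minimal stand-alone version of that metric:

def word_error_rate(ref, hyp):
    # Levenshtein distance over words (substitutions, insertions, deletions)
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,  # deletion
                d[i][j - 1] + 1,  # insertion
                d[i - 1][j - 1] + cost,  # substitution (or match)
            )
    return d[len(ref)][len(hyp)] / max(1, len(ref))

assert word_error_rate("the cat sat".split(), "a cat sat down".split()) == 2 / 3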
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-right-padding-{params.right_padding}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + 
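# Illustrative sketch (assumes checkpoints store the model under a "model"
# key, as the icefall helpers expect): `average_checkpoints` averages the
# parameters of the selected checkpoints element-wise, while
# `average_checkpoints_with_averaged_model` reconstructs the average over an
# epoch/iteration range from the cumulative averaged models saved during
# training, so only the two boundary checkpoints need to be loaded.  A
# bare-bones version of plain state_dict averaging (ignoring non-float
# buffers and other special handling) might look like:

import torch

def naive_average_checkpoints(filenames, device="cpu"):
    avg = None
    for f in filenames:
        state = torch.load(f, map_location=device)["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}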
logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
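# (With `return_cuts=True`, lhotse's K2SpeechRecognitionDataset puts the cuts
# into batch["supervisions"]["cut"], so decode_dataset can key every
# hypothesis/reference pair by its cut id in the recogs-*.txt files.)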
+ args.return_cuts = True + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py new file mode 100644 index 000000000..a4f52ad7f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py @@ -0,0 +1,837 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) + +# from asr_datamodule import LibriSpeechAsrDataModule +from gigaspeech_asrmodule import GigaSpeechAsrDataModule +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--right-padding", + type=int, + default=64, + help="Padding frames at the end of features", + ) + + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
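# Illustrative sketch: the `--right-padding` frames are appended to the time
# axis and filled with LOG_EPS = math.log(1e-10) (defined near the top of this
# file), i.e. an extremely small log value, so the padded frames carry
# essentially no energy while giving the last real frames enough right context
# inside the streaming encoder:

import math
import torch
import torch.nn.functional as F

LOG_EPS = math.log(1e-10)
feature = torch.randn(2, 100, 80)  # (N, T, C)
right_padding = 64
padded = F.pad(feature, pad=(0, 0, 0, right_padding), value=LOG_EPS)
assert padded.shape == (2, 100 + right_padding, 80)
# padded[:, 100:, :] is filled with LOG_EPS (about -23.03)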
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + feature_lens += params.right_padding + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.right_padding), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += 
f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
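# Illustrative sketch: `post_processing` above runs both the reference and the
# hypothesis through GigaSpeech's `asr_text_post_processing` before scoring,
# so punctuation tags and conversational fillers are not counted as errors.
# The exact tag and filler lists live in gigaspeech_scoring.py; the ones below
# are only an indicative stand-in:

def rough_gigaspeech_normalize(text: str) -> str:
    punctuation_tags = {"<COMMA>", "<PERIOD>", "<QUESTIONMARK>", "<EXCLAMATIONPOINT>"}
    fillers = {"UH", "UHH", "UM", "EH", "MM", "HM", "AH", "HUH", "HA", "ER"}
    words = [w for w in text.upper().split() if w not in punctuation_tags | fillers]
    return " ".join(words)

assert rough_gigaspeech_normalize("yeah <COMMA> i see <PERIOD>") == "YEAH I SEE"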
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + """ + This scripts test a libri model with libri BPE + on Gigaspeech. + """ + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / (params.decoding_method + "_gigaspeech") + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-right-padding-{params.right_padding}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + 
logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
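# (For plain fast_beam_search the decoding graph above is
# k2.trivial_graph(vocab_size - 1), effectively an unconstrained graph that
# accepts any token sequence with zero cost; for fast_beam_search_nbest_LG the
# LG.pt graph is loaded instead and its arc scores, which carry the n-gram LM,
# are scaled by --ngram-lm-scale before the search.)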
+ args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_stream.py new file mode 120000 index 000000000..2b4596e0b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_stream.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decoder.py new file mode 120000 index 000000000..4e79a10e0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/encoder_interface.py new file mode 120000 index 000000000..24f414dd1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn-zh.py new file mode 120000 index 000000000..c0e71accf --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn-zh.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py new file mode 100755 index 000000000..f5589d1b2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. 
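# Illustrative sketch: the export below points `forward` at
# `streaming_forward` because torch.jit.trace always records whatever
# `forward` does, so after the swap the trace captures the streaming code
# path.  A minimal demonstration of that pattern with a toy module:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

    def streaming_forward(self, x):
        return x + 2

Toy.forward = Toy.streaming_forward
traced = torch.jit.trace(Toy(), (torch.zeros(2),))
assert torch.equal(traced(torch.zeros(2)), torch.zeros(2) + 2)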
+ """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + 
model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-onnx.py new file mode 120000 index 000000000..137fa8cec --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-onnx.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export.py new file mode 120000 index 000000000..6a009311c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech.py new file mode 120000 index 000000000..6c6b08d3f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless8/gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_asrmodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_asrmodule.py new file mode 120000 index 000000000..54f18a4f0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_asrmodule.py @@ -0,0 +1 @@ +../../../gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_scoring.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_scoring.py new file mode 120000 index 000000000..fdfa6ce4b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/gigaspeech_scoring.py @@ -0,0 +1 @@ +../../../gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_pretrained.py new file mode 120000 index 000000000..c427e7709 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/jit_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_export.py new file mode 120000 index 000000000..44ecf1780 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_export.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/jit_trace_export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_pretrained.py new file mode 120000 index 000000000..762d38b73 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/jit_trace_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/joiner.py new file mode 120000 index 000000000..2a9c1ca5f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/librispeech.py new file mode 120000 index 000000000..7c22bc4b7 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/librispeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless8/librispeech.py \ No newline at end of file diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/model.py new file mode 120000 index 000000000..17ced2998 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_check.py new file mode 120000 index 000000000..8ed81ba1c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_model_wrapper.py new file mode 120000 index 000000000..c780015d1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_model_wrapper.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_pretrained.py new file mode 120000 index 000000000..da0236c2d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/optim.py new file mode 120000 index 000000000..6c5f3fc3e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/pretrained.py new file mode 120000 index 000000000..4c519b771 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling.py new file mode 120000 index 000000000..420bc4149 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling_converter.py new file mode 120000 index 000000000..b6cc7dc13 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming-ncnn-decode.py 
b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming-ncnn-decode.py new file mode 120000 index 000000000..d137d28ad --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_beam_search.py new file mode 120000 index 000000000..dee9005d0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py new file mode 100644 index 000000000..78713f920 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py @@ -0,0 +1,610 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from librispeech import LibriSpeech +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming_multi/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num-active-paths", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded in parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk of frames for each stream in decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to an utterance. + Returns: + Return a List containing which DecodeStreams are finished. 
+ """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. + tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> and <unk> are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.unk_id = sp.piece_to_id("<unk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough 
checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/test_model.py new file mode 120000 index 000000000..1a9ba93e6 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/test_model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/test_model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py new file mode 100755 index 000000000..09e8a512f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -0,0 +1,1370 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from gigaspeech import GigaSpeech +from joiner import Joiner +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from librispeech import LibriSpeech +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols (context) " + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version of the " + "loss (the joiner is just addition); this simple loss is also used for " + "training (as a regularization term). We will scale the simple loss " + "with this parameter before adding it to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--giga-prob", + type=float, + default=0.5, + help="The probability to select a batch from the GigaSpeech dataset", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> 
Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def is_libri(c: Cut) -> bool: + """Return True if this cut is from the LibriSpeech dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of GigaSpeech to dict(origin='giga') + See ../local/preprocess_gigaspeech.py. + """ + return c.supervisions[0].custom is None + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. 
+ + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + giga_train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + libri_tot_loss = MetricsTracker() + giga_tot_loss = MetricsTracker() + tot_loss = MetricsTracker() + + # index 0: for LibriSpeech + # index 1: for GigaSpeech + # This sets the probabilities for choosing which datasets + dl_weights = [1 - params.giga_prob, params.giga_prob] + iter_libri = iter(train_dl) + iter_giga = iter(giga_train_dl) + + batch_idx = 0 + + while True: + idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] + dl = iter_libri if idx == 0 else iter_giga + + try: + batch = next(dl) + except StopIteration: + name = "libri" if idx == 0 else "giga" + logging.info(f"{name} reaches end of dataloader") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + libri = is_libri(batch["supervisions"]["cut"][0]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + if libri: + libri_tot_loss = ( + libri_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "libri" # for logging only + else: + giga_tot_loss = ( + giga_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "giga" + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. 
The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, {prefix}_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"libri_tot_loss[{libri_tot_loss}], " + f"giga_tot_loss[{giga_tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + libri_tot_loss.write_summary( + tb_writer, "train/libri_tot_", params.batch_idx_train + ) + giga_tot_loss.write_summary( + tb_writer, "train/giga_tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, sp: spm.SentencePieceProcessor +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + return False + + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + 
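# --- Illustrative aside (not part of the patch) ------------------------------
# The T >= S check in filter_short_and_long_utterances() above depends on the
# subsampling arithmetic T = ((num_frames - 7) // 2 + 1) // 2 used by the
# convolutional front-end in ./zipformer.py. Below is a minimal, self-contained
# sketch of that bound; the helper name `frames_after_subsampling` is made up
# here for illustration only.
def frames_after_subsampling(num_frames: int) -> int:
    # Mirrors the expression used in remove_short_and_long_utt().
    return ((num_frames - 7) // 2 + 1) // 2

# A 10-second cut with a 10 ms frame shift has about 1000 feature frames,
# which leaves 248 frames after subsampling, so at most 248 BPE tokens fit.
assert frames_after_subsampling(1000) == 248
# ------------------------------------------------------------------------------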
librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + train_cuts = filter_short_and_long_utterances(train_cuts, sp) + + gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) + # XL 10k hours + # L 2.5k hours + # M 1k hours + # S 250 hours + # XS 10 hours + # DEV 12 hours + # Test 40 hours + if params.full_libri: + logging.info("Using the XL subset of GigaSpeech (10k hours)") + train_giga_cuts = gigaspeech.train_XL_cuts() + else: + logging.info("Using the S subset of GigaSpeech (250 hours)") + train_giga_cuts = gigaspeech.train_S_cuts() + + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp) + train_giga_cuts = train_giga_cuts.repeat(times=None) + + if args.enable_musan: + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + else: + cuts_musan = None + + asr_datamodule = AsrDataModule(args) + + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + giga_train_dl = asr_datamodule.train_dataloaders( + train_giga_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + giga_train_dl=giga_train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
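    Example:
      An illustrative way to inspect a saved batch afterwards; the file name
      below is only a placeholder for whatever uuid was generated:

        import torch

        batch = torch.load(
            "pruned_transducer_stateless8/exp/batch-xxxx.pt", map_location="cpu"
        )
        print(batch["inputs"].shape)          # (N, T, C) features
        print(batch["supervisions"]["text"])  # reference transcripts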
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + assert 0 <= args.giga_prob < 1, args.giga_prob + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py new file mode 120000 index 000000000..3c3280b68 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/train2.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer.py new file mode 120000 index 000000000..be9e75bfa --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py new file mode 120000 index 000000000..d3625f478 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless8/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py new file mode 120000 index 000000000..3ba9ada4f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py new file mode 100755 index 000000000..e07777c9f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -0,0 +1,797 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless8/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --max-duration 600 \ + --decoding-method 
fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from librispeech import LibriSpeech +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
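    Example (illustrative only; the hypotheses below are made up):

        {
            "greedy_search": [
                ["HELLO", "WORLD"],
                ["GOOD", "MORNING"],
            ]
        }

      i.e. one entry per decoding setting, and for each utterance in the
      batch a list of predicted words.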
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + 
key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + 
start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
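    # (Illustrative note: setting return_cuts=True below makes the test
    # dataloaders attach the lhotse cuts to each batch, which is what lets
    # decode_dataset() above do
    #     cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
    # so every hypothesis can be stored next to its utterance id.)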
+ args.return_cuts = True + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py new file mode 100755 index 000000000..d4a228b47 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. 
You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless8/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless8/decode.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), + strict=False, + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
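    # (After torch.jit.script() the exported model is used through its
    # submodules rather than forward(); e.g. jit_pretrained.py below calls,
    # roughly,
    #     encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
    # and then runs greedy search on encoder_out.)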
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py new file mode 120000 index 000000000..5242c652a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/gigaspeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/gigaspeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py new file mode 100755 index 000000000..129497d5a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless8/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless8/exp/cpu_jit.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = 
kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py b/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py new file mode 120000 index 000000000..b76723bf5 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/librispeech.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/librispeech.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py new file mode 100644 index 000000000..39a360796 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -0,0 +1,220 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
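Aside: the batched greedy_search() in jit_pretrained.py above leans on how
torch.nn.utils.rnn.pack_padded_sequence orders frames. A minimal standalone
sketch with made-up lengths, showing where batch_sizes and unsorted_indices
come from:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    # Three "encoder outputs" with feature dim 2 and valid lengths 3, 1 and 2.
    x = torch.arange(3 * 3 * 2, dtype=torch.float32).reshape(3, 3, 2)
    lens = torch.tensor([3, 1, 2])

    packed = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
    print(packed.batch_sizes)       # tensor([3, 2, 1]): utterances still active per frame
    print(packed.data.shape)        # torch.Size([6, 2]): only the valid frames are kept
    print(packed.unsorted_indices)  # maps the sorted order back to the input order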
+ + +import random +from typing import Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import penalize_abs_values_gt + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + decoder_giga: Optional[nn.Module] = None, + joiner_giga: Optional[nn.Module] = None, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.decoder_giga = decoder_giga + self.joiner_giga = joiner_giga + + self.simple_am_proj = nn.Linear( + encoder_dim, + vocab_size, + ) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + + if decoder_giga is not None: + self.simple_am_proj_giga = nn.Linear(encoder_dim, vocab_size) + self.simple_lm_proj_giga = nn.Linear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + libri: bool = True, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + libri: + True to use the decoder and joiner for the LibriSpeech dataset. + False to use the decoder and joiner for the GigaSpeech dataset. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. 
+ + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + if libri: + decoder = self.decoder + simple_lm_proj = self.simple_lm_proj + simple_am_proj = self.simple_am_proj + joiner = self.joiner + else: + decoder = self.decoder_giga + simple_lm_proj = self.simple_lm_proj_giga + simple_am_proj = self.simple_am_proj_giga + joiner = self.joiner_giga + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = simple_lm_proj(decoder_out) + am = simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=joiner.encoder_proj(encoder_out), + lm=joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
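        # As a rough, illustrative example (the numbers are arbitrary): with
        # B=2, T=500, prune_range=5 and vocab_size=500, the pruned logits
        # computed below have shape (2, 500, 5, 500), whereas the full RNN-T
        # lattice would be (2, 500, S + 1, 500) for S target tokens, which is
        # far larger for long utterances.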
+ logits = joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py new file mode 100755 index 000000000..486d9d74e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless8/export.py \ + --exp-dir ./pruned_transducer_stateless8/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless8/pretrained.py \ + --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless8/exp/epoch-xx.pt`. 
+ +Note: ./pruned_transducer_stateless8/exp/pretrained.pt is generated by +./pruned_transducer_stateless8/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params, enable_giga=False) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for 
filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py new file mode 100755 index 000000000..b0abad5ae --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -0,0 +1,1362 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +cd egs/librispeech/ASR/ +./prepare.sh +./prepare_giga_speech.sh + +./pruned_transducer_stateless8/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless8/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless8/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless8/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from gigaspeech import GigaSpeech +from joiner import Joiner +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from librispeech import LibriSpeech +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless8/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--giga-prob", + type=float, + default=0.5, + help="The probability to select a batch from the GigaSpeech dataset", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model( + params: AttributeDict, + enable_giga: bool = True, +) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + if enable_giga: + logging.info("Use giga") + decoder_giga = get_decoder_model(params) + joiner_giga = get_joiner_model(params) + else: + logging.info("Disable giga") + decoder_giga = None + joiner_giga = None + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + decoder_giga=decoder_giga, + joiner_giga=joiner_giga, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: 
AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def is_libri(c: Cut) -> bool: + """Return True if this cut is from the LibriSpeech dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of GigaSpeech to dict(origin='giga') + See ../local/preprocess_gigaspeech.py. 
+ """ + return c.supervisions[0].custom is None + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + libri = is_libri(supervisions["cut"][0]) + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + libri=libri, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
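    # The per-frame average is only formed later, e.g. as
    # tot_loss["loss"] / tot_loss["frames"] in train_one_epoch().
    #
    # To make the warm-up schedule above concrete (illustrative numbers,
    # assuming the defaults warm_step=2000 and simple_loss_scale=0.5):
    #   batch       0: simple_loss_scale = 1.0,  pruned_loss_scale = 0.1
    #   batch    1000: simple_loss_scale = 0.75, pruned_loss_scale = 0.55
    #   batch >= 2000: simple_loss_scale = 0.5,  pruned_loss_scale = 1.0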
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + giga_train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + giga_train_dl: + Dataloader for the GigaSpeech training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + For selecting which dataset to use. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + libri_tot_loss = MetricsTracker() + giga_tot_loss = MetricsTracker() + tot_loss = MetricsTracker() + + # index 0: for LibriSpeech + # index 1: for GigaSpeech + # This sets the probabilities for choosing which datasets + dl_weights = [1 - params.giga_prob, params.giga_prob] + + iter_libri = iter(train_dl) + iter_giga = iter(giga_train_dl) + + batch_idx = 0 + + while True: + idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] + dl = iter_libri if idx == 0 else iter_giga + + try: + batch = next(dl) + except StopIteration: + name = "libri" if idx == 0 else "giga" + logging.info(f"{name} reaches end of dataloader") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + libri = is_libri(batch["supervisions"]["cut"][0]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + if libri: + libri_tot_loss = ( + libri_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "libri" # for logging only + else: + giga_tot_loss = ( + giga_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "giga" + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
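            # In short (mirroring the code below): every 100 batches the scale
            # is doubled if it has dropped below 1.0, and every 400 batches it
            # is doubled while it is still below 8.0; a scale below 0.01 logs a
            # warning, and below 1e-5 training is aborted.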
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, {prefix}_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"libri_tot_loss[{libri_tot_loss}], " + f"giga_tot_loss[{giga_tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, + f"train/current_{prefix}_", + params.batch_idx_train, + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + libri_tot_loss.write_summary( + tb_writer, "train/libri_tot_", params.batch_idx_train + ) + giga_tot_loss.write_summary( + tb_writer, "train/giga_tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, sp: spm.SentencePieceProcessor +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=True) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + train_cuts = filter_short_and_long_utterances(train_cuts, sp) + + gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) + # XL 10k hours + # L 2.5k hours + # M 1k hours + # S 250 hours + # XS 10 hours + # DEV 12 hours + # Test 40 hours + if params.full_libri: + logging.info("Using the XL subset of GigaSpeech (10k hours)") + train_giga_cuts = gigaspeech.train_XL_cuts() + else: + logging.info("Using the S subset of GigaSpeech (250 hours)") + train_giga_cuts = gigaspeech.train_S_cuts() + + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp) 
+ train_giga_cuts = train_giga_cuts.repeat(times=None) + + if args.enable_musan: + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + else: + cuts_musan = None + + asr_datamodule = AsrDataModule(args) + + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + giga_train_dl = asr_datamodule.train_dataloaders( + train_giga_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + giga_train_dl=giga_train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + assert 0 <= args.giga_prob < 1, args.giga_prob + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 01be7090b..53f383c99 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded:

-streaming_models/  
-|-- lang_bpe  
-|   |-- L.pt  
-|   |-- Linv.pt  
+streaming_models/
+|-- lang_bpe
+|   |-- L.pt
+|   |-- Linv.pt
 |   |-- bpe.model
 |   |-- tokens.txt
 |   `-- words.txt
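The README above also asks you to verify the downloaded files against their
md5sum values. As an optional sketch (not part of the recipe; the path below is
just taken from the listing above), the same check can be done from Python:

    import hashlib

    def md5_of(path: str, chunk_size: int = 1 << 20) -> str:
        """Return the hex md5 digest of a file, read in chunks."""
        h = hashlib.md5()
        with open(path, "rb") as f:
            for block in iter(lambda: f.read(chunk_size), b""):
                h.update(block)
        return h.hexdigest()

    # e.g. compare md5_of("streaming_models/lang_bpe/L.pt") with the published
    # checksum before running streaming_decode.py.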
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
index ff4c91446..5fe92172e 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
@@ -309,36 +309,26 @@ class Conformer(Transformer):
 
                 # start chunk_by_chunk decoding
                 offset = 0
-                for cur in range(
-                    0, num_frames - embed_left_context + 1, stride
-                ):
+                for cur in range(0, num_frames - embed_left_context + 1, stride):
                     end = min(cur + decoding_window, num_frames)
                     cur_feature = feature[:, cur:end, :]
                     cur_feature = self.encoder_embed(cur_feature)
-                    cur_embed, cur_pos_emb = self.encoder_pos(
-                        cur_feature, offset
-                    )
-                    cur_embed = cur_embed.permute(
-                        1, 0, 2
-                    )  # (B, T, F) -> (T, B, F)
+                    cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset)
+                    cur_embed = cur_embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
 
                     cur_T = cur_feature.size(1)
                     if cur == 0:
                         # for first chunk extract the central pos embedding
-                        pos_emb_central = cur_pos_emb[
-                            0, (chunk_size - 1), :
-                        ].view(1, 1, -1)
+                        pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view(
+                            1, 1, -1
+                        )
                         cur_T -= 1
                     pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
                     pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
                     assert pos_emb_positive[-1].size(0) == cur_T
 
-                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
-                        0
-                    )
-                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
-                        0
-                    )
+                    pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0)
+                    pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0)
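                    # Roughly: pos_emb_pos now holds the embeddings for the
                    # positive relative offsets accumulated so far and
                    # pos_emb_neg the negative ones; together with the central
                    # (offset-0) embedding cached from the first chunk they are
                    # stitched back into the full relative-position sequence
                    # just below.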
                     cur_pos_emb = torch.cat(
                         [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
                         dim=1,
@@ -413,9 +403,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -431,22 +419,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FFN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
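        # ff_scale = 0.5 is the macaron-style half-step feed-forward weight:
        # each of the two feed-forward modules is added to the residual as
        # `residual + self.ff_scale * self.dropout(...)` in forward().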
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -480,9 +462,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -554,9 +534,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -736,9 +714,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -755,9 +731,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -783,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, offset: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -813,9 +785,7 @@ class RelPositionalEncoding(torch.nn.Module):
             pos_emb = torch.cat(
                 [
                     pos_emb[:, : (x_T - 1)],
-                    self.pe[0, self.pe.size(1) // 2].view(
-                        1, 1, self.pe.size(-1)
-                    ),
+                    self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)),
                     pos_emb[:, -(x_T - 1) :],  # noqa: E203
                 ],
                 dim=1,
@@ -1050,9 +1020,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1120,31 +1090,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1185,24 +1146,16 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(
-            matrix_bd, offset=offset
-        )  # [B, head, time1, time2]
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd, offset=offset)  # [B, head, time1, time2]
         attn_output_weights = (
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1236,13 +1189,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
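
Aside on the conformer.py hunks above: they only re-wrap the relative-attention code, but the matrix_ac / matrix_bd comments are easier to follow with the relative-shift step written out. Below is a minimal, self-contained sketch (the helper name rel_shift_sketch and the toy sizes are illustrative, not taken from the patch); it assumes column k of the score matrix corresponds to relative distance t - 1 - k, which is the layout the surrounding code builds.

import torch


def rel_shift_sketch(x: torch.Tensor) -> torch.Tensor:
    """Map scores over relative positions (b, h, t, 2*t - 1) to (b, h, t, t).

    Column k of the input holds q_i . p_k, where p_k encodes relative
    distance t - 1 - k; output column j holds the score of query i
    against key j, i.e. input column j - i + t - 1.
    """
    b, h, t, n = x.shape
    assert n == 2 * t - 1, n
    zero_pad = torch.zeros(b, h, t, 1, dtype=x.dtype, device=x.device)
    x_padded = torch.cat([zero_pad, x], dim=-1)       # (b, h, t, 2*t)
    x_padded = x_padded.view(b, h, 2 * t, t)          # reinterpret the layout
    return x_padded[:, :, 1:].reshape(b, h, t, 2 * t - 1)[..., :t]


if __name__ == "__main__":
    b, h, t = 2, 4, 5
    x = torch.randn(b, h, t, 2 * t - 1)
    y = rel_shift_sketch(x)
    # brute-force check of the index bookkeeping
    for i in range(t):
        for j in range(t):
            assert torch.equal(y[..., i, j], x[..., i, j - i + t - 1])
    print("rel_shift ok:", tuple(y.shape))

The brute-force check at the bottom is what makes the pad-and-reshape trick easy to trust.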
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
index a74c51836..a26d0b789 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
@@ -28,6 +28,7 @@ import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
@@ -86,8 +87,7 @@ def get_parser():
         "--tailing-num-frames",
         type=int,
         default=20,
-        help="tailing dummy frames padded to the right,"
-        "only used during decoding",
+        help="tailing dummy frames padded to the right, only used during decoding",
     )
 
     parser.add_argument(
@@ -248,13 +248,9 @@ def decode_one_batch(
     maxlen = nnet_output.size(1)
     topk_prob, topk_index = nnet_output.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
-    topk_index = topk_index.masked_fill_(
-        memory_key_padding_mask, 0
-    )  # (B, maxlen)
+    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)  # (B, maxlen)
     token_ids = [token_id.tolist() for token_id in topk_index]
-    token_ids = [
-        remove_duplicates_and_blank(token_id) for token_id in token_ids
-    ]
+    token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids]
     hyps = bpe_model.decode(token_ids)
     hyps = [s.split() for s in hyps]
     return {key: hyps}
@@ -337,9 +333,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return results
 
@@ -364,8 +358,7 @@ def save_results(
         -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir
-            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -374,8 +367,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir
-            / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
+            params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -384,9 +376,7 @@ def save_results(
             test_set_wers[key] = wer
 
         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -474,9 +464,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -507,9 +495,7 @@ def main():
             simulate_streaming=params.simulate_streaming,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
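
For readers of decode_one_batch above: the masked_fill_ plus remove_duplicates_and_blank lines implement plain CTC greedy decoding. A small stand-alone sketch follows, with remove_duplicates_and_blank written inline as a stand-in for the helper the recipe imports (blank id 0 assumed).

from typing import List

import torch


def remove_duplicates_and_blank(token_ids: List[int], blank_id: int = 0) -> List[int]:
    # collapse consecutive repeats, then drop blanks (standard CTC collapse)
    out = []
    prev = None
    for t in token_ids:
        if t != prev and t != blank_id:
            out.append(t)
        prev = t
    return out


if __name__ == "__main__":
    # fake frame-level posteriors for one utterance: B=1, maxlen=10, 5 classes
    nnet_output = torch.randn(1, 10, 5).log_softmax(dim=2)
    topk_index = nnet_output.argmax(dim=2)                     # (B, maxlen)
    # frames past the real length are forced to blank (id 0), mirroring the
    # masked_fill_ with memory_key_padding_mask in decode_one_batch
    memory_key_padding_mask = torch.zeros(1, 10, dtype=torch.bool)
    memory_key_padding_mask[0, 7:] = True
    topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0)
    token_ids = [remove_duplicates_and_blank(t.tolist()) for t in topk_index]
    print(token_ids)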
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
index e41b7ea78..bb55ed6bb 100755
--- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py
@@ -30,6 +30,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -405,9 +406,7 @@ def compute_loss(
             #
             # See https://github.com/k2-fsa/icefall/issues/97
             # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
             att_loss = mmodel.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -436,9 +435,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
-        .sum()
-        .item()
+        ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
     )
 
     return loss, info
@@ -551,9 +548,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -651,8 +646,23 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
 
     librispeech = LibriSpeechAsrDataModule(args)
-    train_dl = librispeech.train_dataloaders()
-    valid_dl = librispeech.valid_dataloaders()
+
+    if params.full_libri:
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    train_dl = librispeech.train_dataloaders(train_cuts)
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
     scan_pessimistic_batches_for_oom(
         model=model,
@@ -668,9 +678,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
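
The remove_short_and_long_utt filter added above is just a duration predicate handed to CutSet.filter. A toy sketch with a stand-in cut class (lhotse itself is not needed to see the behavior):

from dataclasses import dataclass


@dataclass
class FakeCut:
    id: str
    duration: float  # seconds


def remove_short_and_long_utt(c: FakeCut) -> bool:
    # Keep only utterances with duration between 1 second and 20 seconds
    return 1.0 <= c.duration <= 20.0


if __name__ == "__main__":
    cuts = [
        FakeCut("too-short", 0.4),
        FakeCut("ok-1", 3.2),
        FakeCut("too-long", 27.0),
        FakeCut("ok-2", 19.9),
    ]
    kept = [c for c in cuts if remove_short_and_long_utt(c)]
    print([c.id for c in kept])  # -> ['ok-1', 'ok-2']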
diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
index bc78e4a41..0c87fdf1b 100644
--- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
                 norm=decoder_norm,
             )
 
-            self.decoder_output_layer = torch.nn.Linear(
-                d_model, self.decoder_num_class
-            )
+            self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
 
             self.decoder_criterion = LabelSmoothingLoss()
         else:
@@ -286,23 +284,17 @@ class Transformer(nn.Module):
         """
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
         ys_out_pad = ys_out_pad.to(device)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -363,23 +355,17 @@ class Transformer(nn.Module):
 
         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
-        ys_in_pad = pad_sequence(
-            ys_in, batch_first=True, padding_value=float(eos_id)
-        )
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
 
         ys_out = add_eos(token_ids, eos_id=eos_id)
         ys_out = [torch.tensor(y) for y in ys_out]
-        ys_out_pad = pad_sequence(
-            ys_out, batch_first=True, padding_value=float(-1)
-        )
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
 
         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
         ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
 
-        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
-            device
-        )
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
 
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         # TODO: Use length information to create the decoder padding mask
@@ -652,9 +638,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
@@ -856,9 +840,7 @@ def encoder_padding_mask(
         1,
     ).to(torch.int32)
 
-    lengths = [
-        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
-    ]
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
     for idx in range(supervision_segments.size(0)):
         # Note: TorchScript doesn't allow to unpack tensors as tuples
         sequence_idx = supervision_segments[idx, 0].item()
@@ -879,9 +861,7 @@ def encoder_padding_mask(
     return mask
 
 
-def decoder_padding_mask(
-    ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
     """Generate a length mask for input.
 
     The masked position are filled with True,
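
The decoder_forward hunks above build teacher-forcing inputs and targets: prepend sos, append eos, pad, and apply a square subsequent mask. Below is a compact sketch with add_sos/add_eos written inline; the names mirror the recipe, but these are illustrative re-implementations, not imports from icefall.

from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
    return [[sos_id] + ids for ids in token_ids]


def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
    return [ids + [eos_id] for ids in token_ids]


def square_subsequent_mask(size: int) -> torch.Tensor:
    # 0 on and below the diagonal, -inf above: position i may attend to <= i
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)


if __name__ == "__main__":
    token_ids = [[5, 7, 9], [3, 4]]
    sos_id = eos_id = 1
    ys_in = [torch.tensor(y) for y in add_sos(token_ids, sos_id)]
    ys_out = [torch.tensor(y) for y in add_eos(token_ids, eos_id)]
    ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
    ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
    tgt_mask = square_subsequent_mask(ys_in_pad.shape[-1])
    print(ys_in_pad)   # [[1, 5, 7, 9], [1, 3, 4, 1]]
    print(ys_out_pad)  # [[5, 7, 9, 1], [3, 4, 1, -1]]
    print(tgt_mask)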
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
index 94d4ed6a3..b1e01a218 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
@@ -1,4 +1,4 @@
 
 Please visit
-
+
 for how to run this recipe.
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 355ccc99a..c5787835d 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -86,8 +86,7 @@ class LibriSpeechAsrDataModule:
             "--full-libri",
             type=str2bool,
             default=True,
-            help="When enabled, use 960h LibriSpeech. "
-            "Otherwise, use 100h subset.",
+            help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
         )
         group.add_argument(
             "--manifest-dir",
@@ -224,13 +223,9 @@ class LibriSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -252,9 +247,7 @@ class LibriSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -298,9 +291,7 @@ class LibriSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -356,9 +347,7 @@ class LibriSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -425,6 +414,16 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
         )
 
+    @lru_cache()
+    def train_all_shuf_cuts(self) -> CutSet:
+        logging.info(
+            "About to get the shuffled train-clean-100, \
+            train-clean-360 and train-other-500 cuts"
+        )
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
+        )
+
     @lru_cache()
     def dev_clean_cuts(self) -> CutSet:
         logging.info("About to get dev-clean cuts")
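
The new train_all_shuf_cuts getter follows the same @lru_cache pattern as the other cut getters: the manifest is loaded once, and later calls return the cached object. A toy illustration with a plain dict standing in for the CutSet:

from functools import lru_cache


@lru_cache()
def train_all_shuf_cuts_demo() -> dict:
    print("loading manifest ...")  # printed only on the first call
    return {"manifest": "librispeech_cuts_train-all-shuf.jsonl.gz", "num_cuts": 3}


if __name__ == "__main__":
    print(train_all_shuf_cuts_demo()["manifest"])
    print(train_all_shuf_cuts_demo()["num_cuts"])  # served from the cache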
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 7d0cd0bf3..92529e06c 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -336,9 +336,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +398,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -467,9 +463,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -498,9 +492,7 @@ def main():
             G=G,
         )
 
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=500, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
-                for _ in range(5)
-            ]
+            [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 2baeb6bba..fde724866 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -144,10 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
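
On the padding hunk above: features for a batch are padded with log(1e-10), i.e. a very small log energy, and then permuted to (N, C, T) for the TDNN front-end. A tiny sketch with random features of different lengths:

import math

import torch
from torch.nn.utils.rnn import pad_sequence

# two fake utterances with 7 and 4 frames of 80-dim log-Mel features
features = [torch.randn(7, 80), torch.randn(4, 80)]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(features.shape)                 # torch.Size([2, 7, 80])
features = features.permute(0, 2, 1)  # (N, C, T), the layout the TDNN expects
print(features.shape)                 # torch.Size([2, 80, 7])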
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 6b37d5c23..0aa1587ba 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -173,7 +173,7 @@ def get_params() -> AttributeDict:
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp"),
             "lang_dir": Path("data/lang_phone"),
-            "lr": 1e-3,
+            "lr": 1e-4,
             "feature_dim": 80,
             "weight_decay": 5e-4,
             "subsampling_factor": 3,
@@ -355,9 +355,7 @@ def compute_loss(
     info["utt_duration"] = supervisions["num_frames"].sum().item()
     # averaged padding proportion over utterances
     info["utt_pad_proportion"] = (
-        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2))
-        .sum()
-        .item()
+        ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item()
     )
 
     return loss, info
@@ -470,9 +468,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
@@ -561,10 +557,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index 11032f31a..b45b6a9d8 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -123,9 +121,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -157,9 +153,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
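
The cached_key hunk above caches decoder outputs by the full token history while feeding only the last token to the decoder. A toy illustration with a stand-in decoder (the real one also carries LSTM state, which is omitted here):

from typing import Dict, List

import torch


def fake_decoder(decoder_input: torch.Tensor) -> torch.Tensor:
    # stand-in for model.decoder(decoder_input, state); the LSTM state is omitted
    return torch.nn.functional.one_hot(decoder_input.squeeze(0), num_classes=32).float()


cache: Dict[str, torch.Tensor] = {}


def decoder_out_for(ys: List[int]) -> torch.Tensor:
    cached_key = "_".join(map(str, ys))  # key on the full history ...
    if cached_key not in cache:
        # ... but feed only the last emitted token to the decoder
        decoder_input = torch.tensor([ys[-1]]).reshape(1, 1)
        cache[cached_key] = fake_decoder(decoder_input)
    return cache[cached_key]


if __name__ == "__main__":
    print(decoder_out_for([0, 5, 9]).shape, len(cache))  # computed, cache size 1
    print(decoder_out_for([0, 5, 9]).shape, len(cache))  # cache hit, still 1
    print(decoder_out_for([0, 5]).shape, len(cache))     # new history, size 2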
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 5f233df87..8d379d1fa 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -228,9 +228,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -245,9 +243,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -318,9 +314,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -331,18 +325,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -352,10 +342,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
index 5a5db30c4..6db0272f0 100755
--- a/egs/librispeech/ASR/transducer/export.py
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -238,9 +238,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 1db2df648..511610245 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -188,10 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -249,9 +248,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -287,9 +284,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py
index 2a165b0c1..fe8732301 100644
--- a/egs/librispeech/ASR/transducer/rnn.py
+++ b/egs/librispeech/ASR/transducer/rnn.py
@@ -117,12 +117,8 @@ class LayerNormLSTMCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(4 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
@@ -348,9 +344,7 @@ class LayerNormLSTM(nn.Module):
             device=device,
             dtype=dtype,
         )
-        first_layer = LayerNormLSTMLayer(
-            input_size=input_size, **factory_kwargs
-        )
+        first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs)
         layers = [first_layer]
         for i in range(1, num_layers):
             layers.append(
@@ -385,9 +379,7 @@ class LayerNormLSTM(nn.Module):
             - List[(next_h, next_c)] containing the hidden states for all layers
 
         """
-        output_states = torch.jit.annotate(
-            List[Tuple[torch.Tensor, torch.Tensor]], []
-        )
+        output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
         output = input
         for i, rnn_layer in enumerate(self.layers):
             state = states[i]
@@ -456,12 +448,8 @@ class LayerNormGRUCell(nn.Module):
         )
 
         if bias:
-            self.bias_ih = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
-            self.bias_hh = nn.Parameter(
-                torch.empty(3 * hidden_size, **factory_kwargs)
-            )
+            self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
+            self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs))
         else:
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index 8591e2d8a..74c94cc70 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -254,9 +254,7 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -303,9 +301,7 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
         for name, self_param in self_layer.cell.named_parameters():
             getattr(torch_layer, f"{name}_l0").copy_(self_param)
 
-    torch_y, (torch_h, torch_c) = torch_layer(
-        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
-    )
+    torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0)))
     assert_allclose(self_y, torch_y)
     assert_allclose(self_h, torch_h)
     assert_allclose(self_c, torch_c)
@@ -594,9 +590,7 @@ def test_layernorm_gru_cell_forward(device="cpu"):
 
     assert_allclose(self_h, torch_h, atol=1e-5)
 
-    (
-        self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
-    ).sum().backward()
+    (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward()
     (
         torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
     ).sum().backward()
@@ -718,9 +712,7 @@ def test_layernorm_gru_forward(device="cpu"):
     T = torch.randint(low=2, high=100, size=(1,))
 
     x = torch.rand(N, T, input_size, device=device).requires_grad_()
-    states = [
-        torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
-    ]
+    states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
 
     x_clone = x.detach().clone().requires_grad_()
 
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 1dd65eddb..29625754e 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -396,9 +396,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -520,9 +518,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -618,10 +614,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -659,9 +655,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py
index 3531a9633..5342c3e8c 100644
--- a/egs/librispeech/ASR/transducer_lstm/beam_search.py
+++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py
@@ -38,9 +38,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
 
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1)
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
@@ -124,9 +122,7 @@ def beam_search(
     max_u = 20000  # terminate after this number of steps
     u = 0
 
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
+    cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
 
     while t < T and u < max_u:
         # fmt: off
@@ -158,9 +154,9 @@ def beam_search(
             cached_key = "_".join(map(str, y_star.ys))
 
             if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
+                decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape(
+                    1, 1
+                )
 
                 decoder_out, decoder_state = model.decoder(
                     decoder_input,
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 604235e2a..806b68f40 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -225,9 +225,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -242,9 +240,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -315,9 +311,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -328,18 +322,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -349,10 +339,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py
index 3dc992dd2..038d80077 100644
--- a/egs/librispeech/ASR/transducer_lstm/encoder.py
+++ b/egs/librispeech/ASR/transducer_lstm/encoder.py
@@ -48,9 +48,7 @@ class LstmEncoder(EncoderInterface):
         if vgg_frontend:
             self.encoder_embed = VggSubsampling(num_features, real_hidden_size)
         else:
-            self.encoder_embed = Conv2dSubsampling(
-                num_features, real_hidden_size
-            )
+            self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size)
 
         self.rnn = nn.LSTM(
             input_size=hidden_size,
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index cdb801e79..a6f2bd08c 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -400,9 +400,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -524,9 +522,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -624,27 +620,17 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         return 1.0 <= c.duration <= 20.0
 
-    num_in_total = len(train_cuts)
-
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
-
     train_dl = librispeech.train_dataloaders(train_cuts)
 
     valid_cuts = librispeech.dev_clean_cuts()
@@ -665,9 +651,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py
index f143611ea..65f2c58d8 100644
--- a/egs/librispeech/ASR/transducer_stateless/alignment.py
+++ b/egs/librispeech/ASR/transducer_stateless/alignment.py
@@ -193,9 +193,7 @@ def force_alignment(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_active_items, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index ea985f30d..1d79eef9d 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -316,9 +316,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -478,9 +478,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -496,9 +494,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -786,9 +782,7 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
@@ -887,9 +881,7 @@ def _deprecated_modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -959,9 +951,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
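
The reshape hunks above highlight the stateless decoder's defining trick: only the last context_size tokens of a hypothesis are fed to the decoder. A minimal sketch (context_size = 2 and the token values are arbitrary):

import torch

context_size = 2   # the "stateless" decoder sees only the last 2 tokens
blank_id = 0
hyp = [blank_id] * context_size        # history starts with blanks
for y in [17, 23, 5]:                  # pretend these tokens were just emitted
    hyp.append(y)
    decoder_input = torch.tensor([hyp[-context_size:]]).reshape(1, context_size)
    print(decoder_input.tolist())      # [[0, 17]], then [[17, 23]], then [[23, 5]]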
diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
index 48769e9d1..f479389df 100755
--- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py
@@ -124,8 +124,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -162,9 +161,7 @@ def compute_alignments(
 
         feature_lens = supervisions["num_frames"].to(device)
 
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
         batch_size = encoder_out.size(0)
 
@@ -204,9 +201,7 @@ def compute_alignments(
         if batch_idx % 2 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
 
     return CutSet.from_cuts(cuts)
 
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index cde52c9fc..90b722bde 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -24,7 +24,7 @@ import torch
 from torch import Tensor, nn
 from transformer import Transformer
 
-from icefall.utils import make_pad_mask, subsequent_chunk_mask
+from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask
 
 
 class Conformer(Transformer):
@@ -154,7 +154,8 @@ class Conformer(Transformer):
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
 
-        assert x.size(0) == lengths.max().item()
+        if not is_jit_tracing():
+            assert x.size(0) == lengths.max().item()
 
         src_key_padding_mask = make_pad_mask(lengths)
 
@@ -209,10 +210,7 @@ class Conformer(Transformer):
 
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -361,6 +359,11 @@ class Conformer(Transformer):
 
             assert x.size(0) == lengths.max().item()
 
+            if chunk_size < 0:
+                # use full attention
+                chunk_size = x.size(0)
+                left_context = -1
+
             num_left_chunks = -1
             if left_context >= 0:
                 assert left_context % chunk_size == 0
@@ -421,9 +424,7 @@ class ConformerEncoderLayer(nn.Module):
         causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -439,22 +440,16 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
 
         self.ff_scale = 0.5
 
         self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
 
         self.dropout = nn.Dropout(dropout)
 
@@ -486,9 +481,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -514,9 +507,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = residual + self.dropout(src)
 
         if not self.normalize_before:
@@ -581,9 +572,7 @@ class ConformerEncoderLayer(nn.Module):
         residual = src
         if self.normalize_before:
             src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
+        src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
         if not self.normalize_before:
             src = self.norm_ff_macaron(src)
 
@@ -625,9 +614,7 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)
 
-        src, conv_cache = self.conv_module(
-            src, states[1], right_context=right_context
-        )
+        src, conv_cache = self.conv_module(src, states[1], right_context=right_context)
         states[1] = conv_cache
         src = residual + self.dropout(src)
 
@@ -779,11 +766,17 @@ class RelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
+        if is_jit_tracing():
+            # 10k frames correspond to ~100k ms, i.e., about 100 seconds.
+            # We assume the maximum input won't have more than 10k frames.
+            #
+            # TODO(fangjun): Use torch.jit.script() for this module
+            max_len = 10000
+
         self.d_model = d_model
         self.xscale = math.sqrt(self.d_model)
         self.dropout = torch.nn.Dropout(p=dropout_rate)
@@ -798,9 +791,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -826,9 +817,7 @@ class RelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(
-        self, x: torch.Tensor, left_context: int = 0
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
@@ -995,22 +984,34 @@ class RelPositionMultiheadAttention(nn.Module):
           the key, while time1 is for the query).
         """
         (batch_size, num_heads, time1, n) = x.shape
+
         time2 = time1 + left_context
+        if not is_jit_tracing():
+            assert (
+                n == left_context + 2 * time1 - 1
+            ), f"{n} == {left_context} + 2 * {time1} - 1"
 
-        assert (
-            n == left_context + 2 * time1 - 1
-        ), f"{n} == {left_context} + 2 * {time1} - 1"
+        if is_jit_tracing():
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(time2)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
 
-        # Note: TorchScript requires explicit arg for stride()
-        batch_stride = x.stride(0)
-        head_stride = x.stride(1)
-        time1_stride = x.stride(2)
-        n_stride = x.stride(3)
-        return x.as_strided(
-            (batch_size, num_heads, time1, time2),
-            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
-            storage_offset=n_stride * (time1 - 1),
-        )
+            x = x.reshape(-1, n)
+            x = torch.gather(x, dim=1, index=indexes)
+            x = x.reshape(batch_size, num_heads, time1, time2)
+            return x
+        else:
+            # Note: TorchScript requires explicit arg for stride()
+            batch_stride = x.stride(0)
+            head_stride = x.stride(1)
+            time1_stride = x.stride(2)
+            n_stride = x.stride(3)
+            return x.as_strided(
+                (batch_size, num_heads, time1, time2),
+                (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+                storage_offset=n_stride * (time1 - 1),
+            )
 
     def multi_head_attention_forward(
         self,
@@ -1081,20 +1082,23 @@ class RelPositionMultiheadAttention(nn.Module):
         """
 
         tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == embed_dim_to_check
-        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+        if not is_jit_tracing():
+            assert embed_dim == embed_dim_to_check
+            assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
 
         head_dim = embed_dim // num_heads
-        assert (
-            head_dim * num_heads == embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        if not is_jit_tracing():
+            assert (
+                head_dim * num_heads == embed_dim
+            ), "embed_dim must be divisible by num_heads"
+
         scaling = float(head_dim) ** -0.5
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1163,31 +1167,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1210,7 +1205,8 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.transpose(0, 1)  # (batch, time1, head, d_k)
 
         pos_emb_bsz = pos_emb.size(0)
-        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        if not is_jit_tracing():
+            assert pos_emb_bsz in (1, bsz)  # actually it is 1
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
 
         # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
@@ -1228,14 +1224,10 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
 
         matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
 
@@ -1243,15 +1235,14 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) * scaling  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
-        assert list(attn_output_weights.size()) == [
-            bsz * num_heads,
-            tgt_len,
-            src_len,
-        ]
+        if not is_jit_tracing():
+            assert list(attn_output_weights.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                src_len,
+            ]
 
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
@@ -1290,9 +1281,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1302,15 +1291,14 @@ class RelPositionMultiheadAttention(nn.Module):
         )
 
         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+
+        if not is_jit_tracing():
+            assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1418,16 +1406,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
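The rel_shift() change above switches from an as_strided() view to a gather()-based construction whenever torch.jit.trace() is active (the strided view apparently does not trace reliably). A standalone sketch such as the following, which is not part of the patch and uses illustrative sizes, can be used to confirm that the two branches produce identical outputs:

    import torch

    # Illustrative sizes; the only constraint is n == left_context + 2 * time1 - 1.
    batch_size, num_heads, time1, left_context = 2, 4, 5, 3
    time2 = time1 + left_context
    n = left_context + 2 * time1 - 1
    x = torch.randn(batch_size, num_heads, time1, n)

    # Non-tracing branch: a strided view that, for row i, starts at column (time1 - 1 - i).
    strided = x.as_strided(
        (batch_size, num_heads, time1, time2),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (time1 - 1),
    )

    # Tracing branch: build the same shift explicitly with torch.gather().
    rows = torch.arange(start=time1 - 1, end=-1, step=-1)
    cols = torch.arange(time2)
    indexes = rows.repeat(batch_size * num_heads).unsqueeze(-1) + cols
    gathered = torch.gather(x.reshape(-1, n), dim=1, index=indexes)
    gathered = gathered.reshape(batch_size, num_heads, time1, time2)

    assert torch.equal(strided, gathered)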
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 74bba9cad..42125e19f 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +229,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -387,18 +379,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -408,10 +396,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -450,9 +435,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index fbc2373a9..a182d91e2 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -87,9 +87,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py
index 8bd0bdea1..89359f1a4 100755
--- a/egs/librispeech/ASR/transducer_stateless/export.py
+++ b/egs/librispeech/ASR/transducer_stateless/export.py
@@ -109,8 +109,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -244,9 +243,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py
index 93cccbd8c..e1625992d 100644
--- a/egs/librispeech/ASR/transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/transducer_stateless/joiner.py
@@ -60,13 +60,9 @@ class Joiner(nn.Module):
         encoder_out_len: List[int] = encoder_out_len.tolist()
         decoder_out_len: List[int] = decoder_out_len.tolist()
 
-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
+        encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)]
 
-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
+        decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)]
 
         x = [
             e.unsqueeze(1) + d.unsqueeze(0)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index b64521801..915a6069d 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
index b00fc34f1..9af46846a 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py
@@ -140,16 +140,13 @@ def main():
                 token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
             )
             word_starting_time = [
-                "{:.2f}".format(i * frame_shift_in_second)
-                for i in word_starting_frames
+                "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames
             ]
 
             words = supervisions["text"][i].split()
 
             assert len(word_starting_frames) == len(words)
-            word_starting_time_dict[cuts[i].id] = list(
-                zip(words, word_starting_time)
-            )
+            word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time))
 
         # This is a demo script and we exit here after processing
         # one batch.
@@ -160,9 +157,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
index d1350c8ab..65b08d425 100755
--- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py
@@ -29,9 +29,7 @@ from conformer import Conformer
 
 def test_conformer():
     feature_dim = 50
-    c = Conformer(
-        num_features=feature_dim, output_dim=256, d_model=128, nhead=4
-    )
+    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index ae93f3348..8db9b59e7 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -422,9 +421,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -545,9 +542,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -646,10 +641,10 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -664,13 +659,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -698,9 +689,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/librispeech/ASR/transducer_stateless/transformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
     elif activation == "gelu":
         return nn.functional.gelu
 
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
+    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 
 
 class PositionalEncoding(nn.Module):
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index ac2807241..b05fe2a4d 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -171,8 +171,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -230,9 +229,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -387,18 +379,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -408,10 +396,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -450,9 +435,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py
index 57c1a6094..d33d02642 100755
--- a/egs/librispeech/ASR/transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/transducer_stateless2/export.py
@@ -104,8 +104,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -176,9 +175,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index 292f77f03..0738f30c0 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py
index ea15c9040..1c3a33870 100755
--- a/egs/librispeech/ASR/transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/transducer_stateless2/train.py
@@ -136,8 +136,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -410,9 +409,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -533,9 +530,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -634,10 +629,10 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -652,13 +647,9 @@ def run(rank, world_size, args):
         num_removed = num_in_total - num_left
         removed_percent = num_removed / num_in_total * 100
 
-        logging.info(
-            f"Before removing short and long utterances: {num_in_total}"
-        )
+        logging.info(f"Before removing short and long utterances: {num_in_total}")
         logging.info(f"After removing short and long utterances: {num_left}")
-        logging.info(
-            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
-        )
+        logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
     except TypeError as e:
         # You can ignore this error as previous versions of Lhotse work fine
         # for the above code. In recent versions of Lhotse, it uses
@@ -686,9 +677,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index d596e05cb..5570b30ae 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -249,10 +246,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +369,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -388,18 +380,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -409,10 +397,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -451,9 +436,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
index b6b69d932..3735ef452 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py
@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -247,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index f297fa2b2..8c7726367 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -167,8 +167,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -259,9 +257,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -334,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
index ef51a7811..1e1188ca6 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py
@@ -41,9 +41,7 @@ def test_dataset():
     print(args)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 27912738c..dafccd088 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -114,8 +114,7 @@ def get_parser():
         "--full-libri",
         type=str2bool,
         default=True,
-        help="When enabled, use 960h LibriSpeech. "
-        "Otherwise, use 100h subset.",
+        help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
     )
 
     parser.add_argument(
@@ -170,8 +169,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -469,9 +467,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -635,9 +631,7 @@ def train_one_epoch(
                     f"train/current_{prefix}_",
                     params.batch_idx_train,
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 libri_tot_loss.write_summary(
                     tb_writer, "train/libri_tot_", params.batch_idx_train
                 )
@@ -758,10 +752,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     train_cuts = filter_short_and_long_utterances(train_cuts)
 
@@ -784,9 +778,7 @@ def run(rank, world_size, args):
     train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
-        cuts_musan = load_manifest(
-            Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
     else:
         cuts_musan = None
 
@@ -825,9 +817,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/librispeech/ASR/zipformer_mmi/README.md b/egs/librispeech/ASR/zipformer_mmi/README.md
new file mode 100644
index 000000000..e9a37a52a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/README.md
@@ -0,0 +1,26 @@
+This recipe implements the Zipformer-MMI model.
+
+See https://k2-fsa.github.io/icefall/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.html for a detailed tutorial.
+
+It uses **CTC loss for warm-up** and then switches to MMI loss during training.
+
+For decoding, it uses HP (H is ctc_topo, P is a token-level bi-gram LM) as the decoding graph. Supported decoding methods are:
+- **1best**. Extract the best path from the decoding lattice as the decoding result.
+- **nbest**. Extract n paths from the decoding lattice; the path with the highest score is the decoding result.
+- **nbest-rescoring-LG**. Extract n paths from the decoding lattice and rescore them with a word-level 3-gram LM; the path with the highest score is the decoding result.
+- **nbest-rescoring-3-gram**. Extract n paths from the decoding lattice and rescore them with a token-level 3-gram LM; the path with the highest score is the decoding result.
+- **nbest-rescoring-4-gram**. Extract n paths from the decoding lattice and rescore them with a token-level 4-gram LM; the path with the highest score is the decoding result.
+
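+For example, decoding with the nbest-rescoring-LG method can be run roughly as follows (a sketch only; the epoch/avg values, nbest scale, and experiment directory are illustrative and should be adapted to your own setup):
+
+```bash
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 10 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.2 \
+    --decoding-method nbest-rescoring-LG
+```
+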
+Experimental results when training on train-clean-100 (epoch-30-avg-10):
+- 1best: 6.43 & 17.44
+- nbest (nbest-scale=1.2): 6.43 & 17.45
+- nbest-rescoring-LG (nbest-scale=1.2): 5.87 & 16.35
+- nbest-rescoring-3-gram (nbest-scale=1.2): 6.19 & 16.57
+- nbest-rescoring-4-gram (nbest-scale=1.2): 5.87 & 16.07
+
+Experimental results when training on full librispeech (epoch-30-avg-10):
+- 1best: 2.54 & 5.65
+- nbest (nbest-scale=1.2): 2.54 & 5.66
+- nbest-rescoring-LG (nbest-scale=1.2): 2.49 & 5.42
+- nbest-rescoring-3-gram (nbest-scale=1.2): 2.52 & 5.62
+- nbest-rescoring-4-gram (nbest-scale=1.2): 2.5 & 5.51
diff --git a/egs/librispeech/ASR/zipformer_mmi/__init__.py b/egs/librispeech/ASR/zipformer_mmi/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py b/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py
new file mode 100755
index 000000000..33c0bf199
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/decode.py
@@ -0,0 +1,730 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Liyong Guo,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) 1best
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --decoding-method 1best
+(2) nbest
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest
+(3) nbest-rescoring-LG
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-LG
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-3-gram
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer_mmi/exp \
+    --max-duration 100 \
+    --nbest-scale 1.0 \
+    --decoding-method nbest-rescoring-4-gram
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.lexicon import Lexicon
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is token-level bi-gram lm.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with an word-level 3-gram LM, the path with the
+          highest score is the decoding result.
+        - (5) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 3-gram LM, the path with
+          the highest score is the decoding result.
+        - (6) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 4-gram LM, the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "--decoding-method" is one of the following values:
+        nbest, nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kind of n-best based rescoring.
+        Used only when "--decoding-method" is one of the following values:
+        nbest, nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_decoding_params() -> AttributeDict:
+    """Parameters for decoding."""
+    params = AttributeDict(
+        {
+            "frame_shift_ms": 10,
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HP: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    G: Optional[k2.Fsa] = None,
+    LG: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+    - key: It indicates the setting used for decoding. For example,
+           if no rescoring is used, the key is the string `no_rescore`.
+           If LM rescoring is used, the key is the string `lm_scale_xxx`,
+           where `xxx` is the value of `lm_scale`. An example key is
+           `lm_scale_0.7`
+    - value: It contains the decoding result. `len(value)` equals to
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - If params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - If params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - If params.decoding_method is "nbest-rescoring-LG", it uses nbest rescoring with a word-level 3-gram LM.
+        - If params.decoding_method is "nbest-rescoring-3-gram", it uses nbest rescoring with a token-level 3-gram LM.
+        - If params.decoding_method is "nbest-rescoring-4-gram", it uses nbest rescoring with a token-level 4-gram LM.
+
+      model:
+        The neural model.
+      HP:
+        The decoding graph. H is ctc_topo, P is token-level bi-gram LM.
+      bpe_model:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      LG:
+        An LM. L is the lexicon, G is a word-level 3-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-LG".
+      G:
+        A token-level 3-gram or 4-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-3-gram"
+        or "nbest-rescoring-4-gram".
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    device = HP.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3, feature.shape
+    feature = feature.to(device)
+
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    nnet_output, encoder_out_lens = model(x=feature, x_lens=feature_lens)
+    # nnet_output is (N, T, C)
+
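+    # Each row of supervision_segments is (sequence_idx, start_frame,
+    # num_frames); frame counts are divided by the subsampling factor so
+    # that they refer to frames of `nnet_output`, not input frames.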
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    method = params.decoding_method
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using HP, not HLG here.
+        #
+        # token_ids is a list-of-lists of token IDs
+        token_ids = get_texts(best_path)
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        return {key: hyps}
+
+    assert method in [
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
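+    # The n-best paths are rescored for every scale in lm_scale_list;
+    # each scale yields its own entry in the returned result dict.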
+
+    if method == "nbest-rescoring-LG":
+        assert LG is not None
+        LM = LG
+    else:
+        assert G is not None
+        LM = G
+    best_path_dict = nbest_rescore_with_LM(
+        lattice=lattice,
+        LM=LM,
+        num_paths=params.num_paths,
+        lm_scale_list=lm_scale_list,
+        nbest_scale=params.nbest_scale,
+    )
+
+    ans = dict()
+    suffix = f"-nbest-scale-{params.nbest_scale}-{params.num_paths}"
+    for lm_scale_str, best_path in best_path_dict.items():
+        token_ids = get_texts(best_path)
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        ans[lm_scale_str + suffix] = hyps
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HP: k2.Fsa,
+    bpe_model: spm.SentencePieceProcessor,
+    G: Optional[k2.Fsa] = None,
+    LG: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HP:
+        The decoding graph. H is ctc_topo, P is token-level bi-gram LM.
+      bpe_model:
+        The BPE model.
+      LG:
+        An LM. L is the lexicon, G is a word-level 3-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-LG".
+      G:
+        A token-level 3-gram or 4-gram LM.
+        It is used when params.decoding_method is "nbest-rescoring-3-gram"
+        or "nbest-rescoring-4-gram".
+
+    Returns:
+      Return a dict, whose key may be "no_rescore" if no LM rescoring
+      is used, or something like "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HP=HP,
+            bpe_model=bpe_model,
+            batch=batch,
+            G=G,
+            LG=LG,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    ), params.decoding_method
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    params.vocab_size = num_classes
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = 0
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    LG = None
+    G = None
+
+    if params.decoding_method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+
+    elif params.decoding_method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
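+        # "nbest-rescoring-3-gram"[-6] == "3" and
+        # "nbest-rescoring-4-gram"[-6] == "4", i.e. the n-gram order.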
+        order = params.decoding_method[-6]
+        assert order in ("3", "4"), (params.decoding_method, order)
+        order = int(order)
+        if not (params.lang_dir / f"{order}gram.pt").is_file():
+            logging.info(f"Loading {order}gram.fst.txt")
+            logging.warning("It may take a few minutes.")
+            with open(params.lang_dir / f"{order}gram.fst.txt") as f:
+                first_token_disambig_id = lexicon.token_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_token_disambig_id] = 0
+                G = k2.Fsa.from_fsas([G]).to(device)
+                # G = k2.remove_epsilon(G)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt")
+        else:
+            logging.info(f"Loading pre-compiled {order}gram.pt")
+            d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
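+        # Keep a copy of the raw G scores as `lm_scores` so that the LM
+        # contribution can be scaled separately during rescoring.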
+        G.lm_scores = G.scores.clone()
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dls = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HP=HP,
+            bpe_model=bpe_model,
+            G=G,
+            LG=LG,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py b/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/export.py b/egs/librispeech/ASR/zipformer_mmi/export.py
new file mode 100755
index 000000000..0af7bd367
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/export.py
@@ -0,0 +1,307 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 9 \
+  --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `zipformer_mmi/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./zipformer_mmi/decode.py \
+        --exp-dir ./zipformer_mmi/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method 1best \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+
+with the following commands:
+
+    sudo apt-get install git-lfs
+    git lfs install
+    git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+    # You will find the pre-trained model in icefall-asr-librispeech-zipformer-mmi-2022-12-08/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        It will generate a file named cpu_jit.pt
+
+        Check ./jit_pretrained.py for how to use it.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script()")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torchscript. Export model.state_dict()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
new file mode 100755
index 000000000..c9ef16ffa
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
@@ -0,0 +1,391 @@
+#!/usr/bin/env python3
+# Copyright      2021-2022  Xiaomi Corp.   (authors: Fangjun Kuang,
+#                                                    Zengwei)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit 1
+
+Usage of this script:
+
+(1) 1best
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(2) nbest
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(3) nbest-rescoring-LG
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-LG \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-3-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/jit_pretrained.py \
+    --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-4-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params
+
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--nn-model-filename",
+        type=str,
+        required=True,
+        help="Path to the torchscript model cpu_jit.pt",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is token-level bi-gram lm.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with an word-level 3-gram LM, the path with the
+          highest score is the decoding result.
+        - (5) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 3-gram LM, the path with
+          the highest score is the decoding result.
+        - (6) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 4-gram LM, the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "--method" is one of the following values:
+        nbest, nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.2,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kind of n-best based rescoring.
+        Used only when "--method" is one of the following values:
+        nbest, nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""
+        Used when method is nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = torch.jit.load(params.nn_model_filename)
+    model.eval()
+    model.to(device)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = 80
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {args.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(
+        features,
+        batch_first=True,
+        padding_value=math.log(1e-10),
+    )
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    method = params.method
+    assert method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    )
+    # loading language model for rescoring
+    LM = None
+    if method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+        LM = LG
+    elif method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
+        order = method[-6]
+        assert order in ("3", "4")
+        order = int(order)
+        logging.info(f"Loading pre-compiled {order}gram.pt")
+        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G.lm_scores = G.scores.clone()
+        LM = G
+
+    # Encoder forward
+    nnet_output, encoder_out_lens = model(x=features, x_lens=feature_lengths)
+
+    batch_size = nnet_output.shape[0]
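+    # One segment per utterance: [utterance_index, start_frame (always 0),
+    # number of frames after subsampling].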
+    supervision_segments = torch.tensor(
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
+        dtype=torch.int32,
+    )
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+    else:
+        best_path_dict = nbest_rescore_with_LM(
+            lattice=lattice,
+            LM=LM,
+            num_paths=params.num_paths,
+            lm_scale_list=[params.ngram_lm_scale],
+            nbest_scale=params.nbest_scale,
+        )
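+        # Only a single LM scale is used here, so the dict has one entry.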
+        best_path = next(iter(best_path_dict.values()))
+
+    # Note: `best_path.aux_labels` contains token IDs, not word IDs
+    # since we are using HP, not HLG here.
+    #
+    # token_ids is a list-of-lists of token IDs
+    token_ids = get_texts(best_path)
+    # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+    hyps = bpe_model.decode(token_ids)
+    # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+    hyps = [s.split() for s in hyps]
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/model.py b/egs/librispeech/ASR/zipformer_mmi/model.py
new file mode 100644
index 000000000..4045c8b64
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/model.py
@@ -0,0 +1,75 @@
+# Copyright    2022  Xiaomi Corp.        (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+
+class CTCModel(nn.Module):
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        encoder_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of shape (N, T, C) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim)
+            and `logit_lens` of shape (N,).
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+
+        self.encoder = encoder
+
+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(encoder_dim, vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+        Returns:
+          Return the ctc outputs and encoder output lengths.
+        """
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+
+        encoder_out, x_lens = self.encoder(x, x_lens)
+        assert torch.all(x_lens > 0)
+
+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
+        return ctc_output, x_lens
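+
+
+# A minimal usage sketch (not part of the recipe; `encoder`, 384, and 500
+# below are placeholders for any EncoderInterface implementation, its output
+# dimension, and the vocabulary size):
+#
+#   model = CTCModel(encoder=encoder, encoder_dim=384, vocab_size=500)
+#   log_probs, out_lens = model(features, feature_lens)
+#   # log_probs has shape (N, T', vocab_size) and holds CTC log-probabilities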
diff --git a/egs/librispeech/ASR/zipformer_mmi/optim.py b/egs/librispeech/ASR/zipformer_mmi/optim.py
new file mode 120000
index 000000000..81ac4a89a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
new file mode 100755
index 000000000..0e7fd0daf
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
@@ -0,0 +1,410 @@
+#!/usr/bin/env python3
+# Copyright      2021-2022  Xiaomi Corp.   (authors: Fangjun Kuang,
+#                                                    Zengwei)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./zipformer_mmi/export.py \
+  --exp-dir ./zipformer_mmi/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) 1best
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method 1best \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(2) nbest
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(3) nbest-rescoring-LG
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-LG \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(4) nbest-rescoring-3-gram
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-3-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+(5) nbest-rescoring-4-gram
+./zipformer_mmi/pretrained.py \
+    --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --nbest-scale 1.2 \
+    --method nbest-rescoring-4-gram \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+
+You can also use `./zipformer_mmi/exp/epoch-xx.pt`.
+
+Note: ./zipformer_mmi/exp/pretrained.pt is generated by
+./zipformer_mmi/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from decode import get_decoding_params
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_ctc_model, get_params
+
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_rescore_with_LM,
+    one_best_decoding,
+)
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method. Use HP as decoding graph, where H is
+        ctc_topo and P is token-level bi-gram lm.
+        Supported values are:
+        - (1) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (2) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (4) nbest-rescoring-LG. Extract n paths from the decoding lattice,
+          rescore them with an word-level 3-gram LM, the path with the
+          highest score is the decoding result.
+        - (5) nbest-rescoring-3-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 3-gram LM, the path with
+          the highest score is the decoding result.
+        - (6) nbest-rescoring-4-gram. Extract n paths from the decoding
+          lattice, rescore them with an token-level 4-gram LM, the path with
+          the highest score is the decoding result.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "--method" is one of the following values:
+        nbest, nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.2,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kind of n-best based rescoring.
+        Used only when "--method" is one of the following values:
+        nbest, nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""
+        Used when method is nbest-rescoring-LG, nbest-rescoring-3-gram,
+        and nbest-rescoring-4-gram.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--hp-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `ctc_topo_P.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    bpe_model = spm.SentencePieceProcessor()
+    bpe_model.load(str(params.lang_dir / "bpe.model"))
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+    HP = mmi_graph_compiler.ctc_topo_P
+    HP.scores *= params.hp_scale
+    if not hasattr(HP, "lm_scores"):
+        HP.lm_scores = HP.scores.clone()
+
+    method = params.method
+    assert method in (
+        "1best",
+        "nbest",
+        "nbest-rescoring-LG",  # word-level 3-gram lm
+        "nbest-rescoring-3-gram",  # token-level 3-gram lm
+        "nbest-rescoring-4-gram",  # token-level 4-gram lm
+    )
+    # loading language model for rescoring
+    LM = None
+    if method == "nbest-rescoring-LG":
+        lg_filename = params.lang_dir / "LG.pt"
+        logging.info(f"Loading {lg_filename}")
+        LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+        LG = k2.Fsa.from_fsas([LG]).to(device)
+        LG.lm_scores = LG.scores.clone()
+        LM = LG
+    elif method in ["nbest-rescoring-3-gram", "nbest-rescoring-4-gram"]:
+        order = method[-6]
+        assert order in ("3", "4")
+        order = int(order)
+        logging.info(f"Loading pre-compiled {order}gram.pt")
+        d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G.lm_scores = G.scores.clone()
+        LM = G
+
+    # Encoder forward
+    nnet_output, encoder_out_lens = model(x=features, x_lens=feature_lengths)
+
+    batch_size = nnet_output.shape[0]
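+    # Each row of supervision_segments is
+    # (sequence_index, start_frame, num_frames), with frame counts
+    # measured after subsampling.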
+    supervision_segments = torch.tensor(
+        [
+            [i, 0, feature_lengths[i] // params.subsampling_factor]
+            for i in range(batch_size)
+        ],
+        dtype=torch.int32,
+    )
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=HP,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if method in ["1best", "nbest"]:
+        if method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+    else:
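+        # nbest_rescore_with_LM() returns a dict mapping each LM scale to its
+        # best path; only a single scale is passed in here, so the dict has
+        # exactly one entry.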
+        best_path_dict = nbest_rescore_with_LM(
+            lattice=lattice,
+            LM=LM,
+            num_paths=params.num_paths,
+            lm_scale_list=[params.ngram_lm_scale],
+            nbest_scale=params.nbest_scale,
+        )
+        best_path = next(iter(best_path_dict.values()))
+
+    # Note: `best_path.aux_labels` contains token IDs, not word IDs
+    # since we are using HP, not HLG here.
+    #
+    # token_ids is a list-of-list of IDs
+    token_ids = get_texts(best_path)
+    # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+    hyps = bpe_model.decode(token_ids)
+    # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+    hyps = [s.split() for s in hyps]
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/scaling.py b/egs/librispeech/ASR/zipformer_mmi/scaling.py
new file mode 120000
index 000000000..2428b74b9
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py b/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py
new file mode 120000
index 000000000..b8b8ba432
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_mmi/test_model.py b/egs/librispeech/ASR/zipformer_mmi/test_model.py
new file mode 100755
index 000000000..7782845f4
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/test_model.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./zipformer_mmi/test_model.py
+"""
+
+import torch
+from train import get_ctc_model, get_params
+
+
+def test_model():
+    params = get_params()
+    params.vocab_size = 500
+    params.num_encoder_layers = "2,4,3,2,4"
+    #  params.feedforward_dims = "1024,1024,1536,1536,1024"
+    params.feedforward_dims = "1024,1024,2048,2048,1024"
+    params.nhead = "8,8,8,8,8"
+    params.encoder_dims = "384,384,384,384,384"
+    params.attention_dims = "192,192,192,192,192"
+    params.encoder_unmasked_dims = "256,256,256,256,256"
+    params.zipformer_downsampling_factors = "1,2,4,8,2"
+    params.cnn_module_kernels = "31,31,31,31,31"
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
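+    # Run a forward pass on a dummy batch: 2 utterances,
+    # each with 100 frames of 80-dim features.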
+    features = torch.randn(2, 100, 80)
+    feature_lengths = torch.full((2,), 100)
+    model(x=features, x_lens=feature_lengths)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py
new file mode 100755
index 000000000..b2784e47c
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/train.py
@@ -0,0 +1,1198 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer_mmi/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir zipformer_mmi/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mixed precision training:
+
+./zipformer_mmi/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer_mmi/exp \
+  --full-libri 1 \
+  --max-duration 500
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import CTCModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon, UniqLexicon
+from icefall.mmi import LFMMILoss
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer_mmi/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--use-pruned-intersect",
+        type=str2bool,
+        default=False,
+        help="""Whether to use `intersect_dense_pruned` to get denominator
+        lattice.""",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for mmi loss
+            "mmi_beam_size": 6,
+            "den_scale": 1.0,
+            # parameters for ctc loss
+            "ctc_beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_ctc_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+
+    model = CTCModel(
+        encoder=encoder,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute ctc loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of CTCModel (with a
+        Zipformer encoder) in our case.
+      ctc_graph_compiler:
+        It is used to build a CTC decoding graph from a ctc topo and the
+        training transcript. The training transcript is contained in the given
+        `batch`, while the ctc topo is built when this compiler is instantiated.
+      mmi_graph_compiler:
+        It is used to build the numerator and denominator graphs for the
+        LF-MMI loss from the training transcript.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_out_lens = model(x=feature, x_lens=feature_lens)
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `LFMMILoss.forward()`
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    info = MetricsTracker()
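+    # Warm up with the CTC loss for the first `warm_step` batches,
+    # then switch to the LF-MMI loss.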
+    if batch_idx_train < warm_step:
+        # Training with ctc loss
+        # Works with a BPE model
+        token_ids = ctc_graph_compiler.texts_to_ids(texts)
+        decoding_graph = ctc_graph_compiler.compile(token_ids)
+        loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.ctc_beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
+        info["ctc_loss"] = loss.detach().cpu().item()
+        info["mmi_loss"] = 0
+    else:
+        # Training with mmi loss
+        loss_fn = LFMMILoss(
+            graph_compiler=mmi_graph_compiler,
+            use_pruned_intersect=params.use_pruned_intersect,
+            den_scale=params.den_scale,
+            beam_size=params.mmi_beam_size,
+        )
+        loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
+        info["ctc_loss"] = 0
+        info["mmi_loss"] = loss.detach().cpu().item()
+
+    assert loss.requires_grad == is_training
+
+    info["frames"] = encoder_out_lens.sum().cpu().item()
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
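+        # Aggregate the statistics across all DDP ranks so that every rank
+        # sees the same validation loss.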
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      ctc_graph_compiler:
+        It is used to convert transcripts to FSAs for the CTC loss.
+      mmi_graph_compiler:
+        It is used to convert transcripts to FSAs for the LF-MMI loss.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    ctc_graph_compiler=ctc_graph_compiler,
+                    mmi_graph_compiler=mmi_graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
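+            # tot_loss is an exponential moving average of loss_info with an
+            # effective window of roughly `reset_interval` batches.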
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(
+                batch, params=params, graph_compiler=mmi_graph_compiler
+            )
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
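+            # Remember the position within the current epoch so that a
+            # resumed run can skip batches it has already seen; the attribute
+            # is removed again right after the checkpoint is written.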
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The
+            # _growth_interval of the grad scaler is configurable, but we
+            # can't configure it to have different behavior depending on the
+            # current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                ctc_graph_compiler=ctc_graph_compiler,
+                mmi_graph_compiler=mmi_graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    if params.full_libri is False:
+        params.valid_interval = 1600
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+    params.vocab_size = num_classes
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    assert "lang_bpe" in str(params.lang_dir)
+    ctc_graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    mmi_graph_compiler = MmiTrainingGraphCompiler(
+        params.lang_dir,
+        uniq_filename="lexicon.txt",
+        device=device,
+        oov="",
+        sos_id=1,
+        eos_id=1,
+    )
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_ctc_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
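+        # model_avg keeps a float64 running average of the parameters;
+        # it is updated every `average_period` batches in train_one_epoch().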
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    # train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        # train_cuts += librispeech.train_clean_360_cuts()
+        # train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            ctc_graph_compiler=ctc_graph_compiler,
+            mmi_graph_compiler=mmi_graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    graph_compiler: MmiTrainingGraphCompiler,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+    y = graph_compiler.texts_to_ids(supervisions["text"])
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    ctc_graph_compiler: BpeCtcTrainingGraphCompiler,
+    mmi_graph_compiler: MmiTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
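+    # `batches` maps each criterion name to the batch of cuts that maximizes
+    # it, and `crit_values` holds the corresponding criterion values.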
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    ctc_graph_compiler=ctc_graph_compiler,
+                    mmi_graph_compiler=mmi_graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(
+                batch, params=params, graph_compiler=mmi_graph_compiler
+            )
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer_mmi/zipformer.py b/egs/librispeech/ASR/zipformer_mmi/zipformer.py
new file mode 120000
index 000000000..79b076556
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_mmi/zipformer.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/README.md b/egs/mgb2/ASR/README.md
new file mode 100644
index 000000000..2bc4b000b
--- /dev/null
+++ b/egs/mgb2/ASR/README.md
@@ -0,0 +1,43 @@
+# MGB2
+
+The Multi-Dialect Broadcast News Arabic Speech Recognition (MGB-2):
+The second edition of the Multi-Genre Broadcast (MGB-2) Challenge is
+an evaluation of speech recognition and lightly supervised alignment
+using TV recordings in Arabic. The speech data is broad and multi-genre,
+spanning the whole range of TV output, and represents a challenging task for
+speech technology. In 2016, the challenge featured two new Arabic tracks based
+on TV data from Aljazeera. It was an official challenge at the 2016 IEEE
+Workshop on Spoken Language Technology. The 1,200 hours of MGB-2 data from
+Aljazeera TV programs have been manually captioned with no timing information.
+The QCRI Arabic ASR system has been used to recognize all programs. The ASR
+output was used to align the manual captioning and produce speech segments for
+training speech recognition. More than 20 hours from 2015 programs have been
+transcribed verbatim and manually segmented. This data is split into a
+development set of 10 hours and a similar evaluation set of 10 hours.
+Both the development and evaluation data were released in the 2016 MGB
+challenge.
+
+Official reference:
+
+Ali, Ahmed, et al. "The MGB-2 challenge: Arabic multi-dialect broadcast media recognition." 
+2016 IEEE Spoken Language Technology Workshop (SLT). IEEE, 2016.
+
+IEEE link: https://ieeexplore.ieee.org/abstract/document/7846277
+
+## Stateless Pruned Transducer Performance Record (after 30 epochs)
+
+|                                    |     dev    |    test    | comment                                  |
+|------------------------------------|------------|------------|------------------------------------------|
+|          greedy search             | 15.52      | 15.28      | --epoch 18, --avg 5, --max-duration 200  |
+| modified beam search               | 13.88      | 13.7       | --epoch 18, --avg 5, --max-duration 200  |
+| fast beam search                   | 14.62      | 14.36      | --epoch 18, --avg 5, --max-duration 200  |
+
+## Conformer-CTC Performance Record (after 40 epochs)
+
+| Decoding method           | dev WER    | test WER |
+|---------------------------|------------|----------|
+| whole-lattice-rescoring   | 15.62      |  15.01   |
+| attention-decoder         | 15.89      |  15.08   |
+
+
+See [RESULTS](/egs/mgb2/ASR/RESULTS.md) for details.
diff --git a/egs/mgb2/ASR/RESULTS.md b/egs/mgb2/ASR/RESULTS.md
new file mode 100644
index 000000000..2a7ea7664
--- /dev/null
+++ b/egs/mgb2/ASR/RESULTS.md
@@ -0,0 +1,236 @@
+# Results
+
+
+### MGB2 all data BPE training results (Stateless Pruned Transducer)
+
+#### 2022-09-07
+
+The WERs are
+
+|                                    |     dev    |    test    | comment                                  |
+|------------------------------------|------------|------------|------------------------------------------|
+|          greedy search             | 15.52      | 15.28      | --epoch 18, --avg 5, --max-duration 200 |
+| modified beam search               | 13.88      | 13.7       | --epoch 18, --avg 5, --max-duration 200 |
+| fast beam search                   | 14.62      | 14.36      | --epoch 18, --avg 5, --max-duration 200|
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 300 \
+  --num-buckets 50
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/YyNv45pfQ0GqWzZ898WOlw/#scalars
+
+The decoding command is:
+```
+epoch=18
+avg=5
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./pruned_transducer_stateless5/decode.py \
+    --epoch $epoch \
+    --beam-size 10 \
+    --avg $avg \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method $method \
+    --max-sym-per-frame 1 \
+    --num-encoder-layers 12 \
+    --dim-feedforward 2048 \
+    --nhead 8 \
+    --encoder-dim 512 \
+    --decoder-dim 512 \
+    --joiner-dim 512 \
+    --use-averaged-model True
+done
+```
+
+### MGB2 all data BPE training results (Conformer-CTC) (after 40 epochs)
+
+#### 2022-06-04
+
+You can find a pretrained model, training logs, decoding logs, and decoding results at:
+https://huggingface.co/AmirHussein/icefall-asr-mgb2-conformer_ctc-2022-27-06
+
+The best WER, as of 2022-06-04, for the MGB2 test dataset is below
+
+Using whole lattice HLG decoding + n-gram LM rescoring 
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER | 15.62      |  15.01     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.1            | -            |
+
+
+Using n-best (nbest-scale=0.5) attention decoder rescoring
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER |    15.89   |  15.08     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.01           | 0.5             |
+
+
+To reproduce the above result, use the following commands for training:
+
+Note: the model was trained on V-100 32GB GPUs.
+
+```
+cd egs/mgb2/ASR
+. ./path.sh
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1"
+./conformer_ctc/train.py \
+  --lang-dir data/lang_bpe_5000 \
+  --att-rate 0.8 \
+  --lr-factor 10 \
+  --concatenate-cuts 0 \
+  --world-size 2 \
+  --bucketing-sampler 1 \
+  --max-duration 100 \
+  --start-epoch 0 \
+  --num-epochs 40
+```
+
+and the following command for nbest decoding
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method attention-decoder \
+  --nbest-scale 0.5
+```
+
+and the following command for whole-lattice decoding
+
+```
+./conformer_ctc/decode.py \
+  --epoch 40 \
+  --avg 5 \
+  --exp-dir conformer_ctc/exp_5000_att0.8 \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --method  whole-lattice-rescoring
+```
+
+
+The tensorboard log for training is available at
+https://tensorboard.dev/experiment/QYNzOi52RwOX8yvtpl3hMw/#scalars
+
+
+### MGB2 100h BPE training results (Conformer-CTC) (after 33 epochs)
+
+#### 2022-06-04
+
+The best WER, as of 2022-06-04, for the MGB2 test dataset is below
+
+Using whole lattice HLG decoding + n-gram LM rescoring 
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER | 25.32      |  23.53     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.1            | -            |
+
+
+Using n-best (nbest-scale=0.5) HLG decoding + n-gram LM rescoring + attention decoder rescoring:
+
+|     | dev        | test       |
+|-----|------------|------------|
+| WER |    27.87   |  26.12     |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.01           | 0.3             |
+
+
+To reproduce the above result, use the following commands for training:
+
+Note: the model was trained on V-100 32GB GPUs.
+
+```
+cd egs/mgb2/ASR
+. ./path.sh
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1"
+./conformer_ctc/train.py \
+  --lang-dir data/lang_bpe_5000 \
+  --att-rate 0.8 \
+  --lr-factor 10 \
+  --concatenate-cuts 0 \
+  --world-size 2 \
+  --bucketing-sampler 1 \
+  --max-duration 100 \
+  --start-epoch 0 \
+  --num-epochs 40
+```
+
+and the following command for nbest decoding
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method attention-decoder \
+  --nbest-scale 0.5
+```
+
+and the following command for whole-lattice decoding
+
+```
+./conformer_ctc/decode.py \
+  --lang-dir data/lang_bpe_5000 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 40 \
+  --avg 5 \
+  --method  whole-lattice-rescoring
+```
+
+The tensorboard log for training is available at
+
diff --git a/egs/mgb2/ASR/conformer_ctc/__init__.py b/egs/mgb2/ASR/conformer_ctc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/conformer_ctc/ali.py b/egs/mgb2/ASR/conformer_ctc/ali.py
new file mode 100755
index 000000000..aea962dcd
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/ali.py
@@ -0,0 +1,395 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+    ./conformer_ctc/ali.py \
+            --exp-dir ./conformer_ctc/exp \
+            --lang-dir ./data/lang_bpe_500 \
+            --epoch 20 \
+            --avg 10 \
+            --max-duration 300 \
+            --dataset train-clean-100 \
+            --out-dir data/ali
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import numpy as np
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse import CutSet
+from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import one_best_decoding
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    get_alignments,
+    setup_logger,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=20,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--out-dir",
+        type=str,
+        required=True,
+        help="""Output directory.
+        It contains 3 generated files:
+
+        - labels_xxx.h5
+        - aux_labels_xxx.h5
+        - cuts_xxx.json.gz
+
+        where xxx is the value of `--dataset`. For instance, if
+        `--dataset` is `train`, it will contain 3 files:
+
+        - `labels_train.h5`
+        - `aux_labels_train.h5`
+        - `cuts_train.json.gz`
+
+        Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
+        alignment. The difference is that labels_xxx.h5 contains repeats.
+        """,
+    )
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=True,
+        help="""The name of the dataset to compute alignments for.
+        Possible values are:
+            - test-clean.
+            - test-other
+            - train-clean-100
+            - train-clean-360
+            - train-other-500
+            - dev-clean
+            - dev-other
+        """,
+    )
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "lm_dir": Path("data/lm"),
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            # Set it to 0 since attention decoder
+            # is not used for computing alignments
+            "num_decoder_layers": 0,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "output_beam": 10,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def compute_alignments(
+    model: torch.nn.Module,
+    dl: torch.utils.data.DataLoader,
+    labels_writer: FeaturesWriter,
+    aux_labels_writer: FeaturesWriter,
+    params: AttributeDict,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+) -> CutSet:
+    """Compute the framewise alignments of a dataset.
+
+    Args:
+      model:
+        The neural network model.
+      dl:
+        Dataloader containing the dataset.
+      labels_writer:
+        The writer used to store the framewise label alignments (with repeats).
+      aux_labels_writer:
+        The writer used to store the framewise aux_label alignments (without
+        repeats).
+      params:
+        Parameters for computing alignments.
+      graph_compiler:
+        It converts token IDs to decoding graphs.
+    Returns:
+      Return a CutSet. Each cut has two custom fields: labels_alignment
+      and aux_labels_alignment, containing framewise alignment information.
+      Both are of type `lhotse.array.TemporalArray`. The difference between
+      the two alignments is that `labels_alignment` contains repeats.
+    """
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+    num_cuts = 0
+
+    device = graph_compiler.device
+    cuts = []
+    for batch_idx, batch in enumerate(dl):
+        feature = batch["inputs"]
+
+        # at entry, feature is [N, T, C]
+        assert feature.ndim == 3
+        feature = feature.to(device)
+
+        supervisions = batch["supervisions"]
+        cut_list = supervisions["cut"]
+
+        for cut in cut_list:
+            assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
+
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is [N, T, C]
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+        # we need also to sort cut_ids as encode_supervisions()
+        # reorders "texts".
+        # In general, new2old is an identity map since lhotse sorts the returned
+        # cuts by duration in descending order
+        new2old = supervision_segments[:, 0].tolist()
+
+        cut_list = [cut_list[i] for i in new2old]
+
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        lattice = k2.intersect_dense(
+            decoding_graph,
+            dense_fsa_vec,
+            params.output_beam,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice,
+            use_double_scores=params.use_double_scores,
+        )
+
+        labels_ali = get_alignments(best_path, kind="labels")
+        aux_labels_ali = get_alignments(best_path, kind="aux_labels")
+        assert len(labels_ali) == len(aux_labels_ali) == len(cut_list)
+        for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali):
+            cut.labels_alignment = labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
+            cut.aux_labels_alignment = aux_labels_writer.store_array(
+                key=cut.id,
+                value=np.asarray(aux_labels, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )
+
+        cuts += cut_list
+
+        num_cuts += len(cut_list)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+    return CutSet.from_cuts(cuts)
+
+
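+# The helper below is an illustrative sketch and is not called anywhere in this
+# script; it shows how the manifest written by main() could be inspected with
+# lhotse. The default path is an assumption used only for the example.
+def _example_load_alignments(manifest_path: str = "data/ali/cuts_train.json.gz"):
+    """Load a cut manifest produced by this script and return the framewise
+    label alignment (token IDs at a 0.04 s frame shift) of its first cut."""
+    cut_set = CutSet.from_file(manifest_path)
+    first_cut = next(iter(cut_set))
+    # `labels_alignment` and `aux_labels_alignment` are the custom fields
+    # attached in compute_alignments() above; load_custom() materializes the
+    # stored lhotse TemporalArray as a numpy array.
+    return first_cut.load_custom("labels_alignment")
+
+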
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    args.enable_spec_aug = False
+    args.enable_musan = False
+    args.return_cuts = True
+    args.concatenate_cuts = False
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-ali")
+
+    logging.info(f"Computing alignments for {params.dataset} - started")
+    logging.info(params)
+
+    out_dir = Path(params.out_dir)
+    out_dir.mkdir(exist_ok=True)
+
+    out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
+    out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
+    out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
+
+    for f in (
+        out_labels_ali_filename,
+        out_aux_labels_ali_filename,
+        out_manifest_filename,
+    ):
+        if f.exists():
+            logging.info(f"{f} exists - skipping")
+            return
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+    model.to(device)
+
+    if params.avg == 1:
+        load_checkpoint(
+            f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
+        )
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+
+    model.eval()
+
+    mgb2 = MGB2AsrDataModule(args)
+    if params.dataset == "train":
+        train_cuts = mgb2.train_cuts()
+        dl = mgb2.train_dataloaders(train_cuts)
+    elif params.dataset == "dev":
+        dev_cuts = mgb2.dev_cuts()
+        dl = mgb2.valid_dataloaders(dev_cuts)
+    else:
+        assert params.dataset == "test", f"{params.dataset}"
+        test_cuts = mgb2.test_cuts()
+        dl = mgb2.test_dataloaders(test_cuts)
+
+    logging.info(f"Processing {params.dataset}")
+    with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer:
+        with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer:
+            cut_set = compute_alignments(
+                model=model,
+                dl=dl,
+                labels_writer=labels_writer,
+                aux_labels_writer=aux_labels_writer,
+                params=params,
+                graph_compiler=graph_compiler,
+            )
+
+    cut_set.to_file(out_manifest_filename)
+
+    logging.info(
+        f"For dataset {params.dataset}, its alignments with repeats are "
+        f"saved to {out_labels_ali_filename}, the alignments without repeats "
+        f"are saved to {out_aux_labels_ali_filename}, and the cut manifest "
+        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
+    )
+
+
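+# Limit PyTorch to a single intra-op/inter-op thread; this (presumably) avoids
+# oversubscribing CPU cores when dataloader worker processes are also running.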
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
new file mode 100644
index 000000000..8242e986d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
@@ -0,0 +1,372 @@
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class MGB2AsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders.
+
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. Used by sampler.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=1,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
+        )
+
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
+
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            # Set the value of num_frame_masks according to Lhotse's version.
+            # In different Lhotse's versions, the default of num_frame_masks is
+            # different.
+            num_frame_masks = 10
+            num_frame_masks_parameter = inspect.signature(
+                SpecAugment.__init__
+            ).parameters["num_frame_masks"]
+            if num_frame_masks_parameter.default == 1:
+                num_frame_masks = 2
+            logging.info(f"Num frame mask: {num_frame_masks}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=num_frame_masks,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        train = K2SpeechRecognitionDataset(
+            cut_transforms=transforms,
+            input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
+        )
+
+        if self.args.on_the_fly_feats:
+            # NOTE: the PerturbSpeed transform should be added only if we
+            # remove it from data prep stage.
+            # Add on-the-fly speed perturbation; since originally it would
+            # have increased epoch size by 3, we will apply prob 2/3 and use
+            # 3x more epochs.
+            # Speed perturbation probably should come first before
+            # concatenation, but in principle the transforms order doesn't have
+            # to be strict (e.g. could be randomized)
+            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
+            # Drop feats to be on the safe side.
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
+
+        if self.args.bucketing_sampler:
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                drop_last=self.args.drop_last,
+            )
+        else:
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+            )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+            return_cuts=self.args.return_cuts,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts, max_duration=self.args.max_duration, shuffle=False
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
+
+    @lru_cache()
+    def dev_cuts(self) -> CutSet:
+        logging.info("About to get dev cuts")
+
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
+
+    @lru_cache()
+    def test_cuts(self) -> CutSet:
+        logging.info("About to get test cuts")
+
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
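+
+
+# The function below is an illustrative usage sketch and is never called on
+# import; it shows how a training or decoding script is expected to wire this
+# datamodule together. The argument values are assumptions for the example.
+def _example_usage():
+    parser = argparse.ArgumentParser()
+    MGB2AsrDataModule.add_arguments(parser)
+    # For the sketch we enable on-the-fly features and use a small batch budget.
+    args = parser.parse_args(["--on-the-fly-feats", "true", "--max-duration", "100"])
+    mgb2 = MGB2AsrDataModule(args)
+    train_dl = mgb2.train_dataloaders(mgb2.train_cuts())
+    valid_dl = mgb2.valid_dataloaders(mgb2.dev_cuts())
+    test_dl = mgb2.test_dataloaders(mgb2.test_cuts())
+    return train_dl, valid_dl, test_dl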
diff --git a/egs/mgb2/ASR/conformer_ctc/compile_hlg.py b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/conformer.py b/egs/mgb2/ASR/conformer_ctc/conformer.py
new file mode 120000
index 000000000..d1f4209d7
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/conformer.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
similarity index 100%
rename from egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
rename to egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py
diff --git a/egs/mgb2/ASR/conformer_ctc/decode.py b/egs/mgb2/ASR/conformer_ctc/decode.py
new file mode 100755
index 000000000..f771d7f1e
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/decode.py
@@ -0,0 +1,695 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    setup_logger,
+    store_transcripts,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=50,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=5,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
+            - (1) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (2) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (5) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (6) nbest-oracle. Its WER is the lower bound that any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring methods.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=20,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "num_decoder_layers": 6,
+            # parameters for decoding
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if no rescoring is used, the key is the string `no_rescore`.
+               If LM rescoring is used, the key is the string `lm_scale_xxx`,
+               where `xxx` is the value of `lm_scale`. An example key is
+               `lm_scale_0.7`
+        - value: It contains the decoding result. `len(value)` equals the
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - If params.method is "1best", it uses 1best decoding without LM rescoring.
+        - If params.method is "nbest", it uses nbest decoding without LM rescoring.
+        - If params.method is "nbest-rescoring", it uses nbest LM rescoring.
+        - If params.method is "whole-lattice-rescoring", it uses whole-lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+    # nnet_output is (N, T, C)
+
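+    # supervision_segments is an int32 tensor of shape (num_utterances, 3);
+    # each row is (sequence_index, start_frame, num_frames), with the frame
+    # counts given after subsampling, as expected by k2's DenseFsaVec.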
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-lists of token IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.method in ["1best", "nbest"]:
+        if params.method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        return {key: hyps}
+
+    assert params.method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "attention-decoder":
+        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
+        )
+        # TODO: pass `lattice` instead of `rescored_lattice` to
+        # `rescore_with_attention_decoder`
+
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            nbest_scale=params.nbest_scale,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no_rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for hyp_words, ref_text in zip(hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for ref_text in texts:
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        store_transcripts(filename=recog_path, texts=results)
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
+            test_set_wers[key] = wer
+
+        if enable_log:
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ):
+        if not (params.lm_dir / "G_4_gram.pt").is_file():
+            logging.info("Loading G_4_gram.fst.txt")
+            logging.warning("It may take 8 minutes.")
+            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+                first_word_disambig_id = lexicon.word_table["#0"]
+
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION: The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                # See https://github.com/k2-fsa/k2/issues/874
+                # for why we need to set G.properties to None
+                G.__dict__["_properties"] = None
+                G = k2.Fsa.from_fsas([G]).to(device)
+                G = k2.arc_sort(G)
+                # Save a dummy value so that it can be loaded in C++.
+                # See https://github.com/pytorch/pytorch/issues/67902
+                # for why we need to do this.
+                G.dummy = 1
+
+                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+        else:
+            logging.info("Loading pre-compiled G_4_gram.pt")
+            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            G = k2.Fsa.from_dict(d)
+
+        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    test_cuts = MGB2.test_cuts()
+    dev_cuts = MGB2.dev_cuts()
+
+    test_dl = MGB2.test_dataloaders(test_cuts)
+    dev_dl = MGB2.test_dataloaders(dev_cuts)
+
+    test_sets = ["test", "dev"]
+    test_all_dl = [test_dl, dev_dl]
+
+    for test_set, test_dl in zip(test_sets, test_all_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/download_lm.py b/egs/mgb2/ASR/conformer_ctc/download_lm.py
new file mode 120000
index 000000000..c9668bd2d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/download_lm.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/download_lm.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/export.py b/egs/mgb2/ASR/conformer_ctc/export.py
new file mode 120000
index 000000000..60e314d9d
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/export.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py
new file mode 100755
index 000000000..d30ca98d8
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py
@@ -0,0 +1,430 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.decode import (
+    get_lattice,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_whole_lattice,
+)
+from icefall.utils import AttributeDict, get_texts
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--words-file",
+        type=str,
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="1best",
+        help="""Decoding method.
+        Possible values are:
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
+        (1) 1best - Use the best path as decoding output. Only
+            the transformer encoder output is used for decoding.
+            We call it HLG decoding.
+        (2) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) attention-decoder - Extract n paths from the rescored
+            lattice and use the transformer attention decoder for
+            rescoring.
+            We call it HLG decoding + n-gram LM rescoring + attention
+            decoder rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or attention-decoder.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the size of n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring or attention-decoder.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--attention-decoder-scale",
+        type=float,
+        default=1.2,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for attention decoder scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique number of paths with the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--sos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies ID for the SOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--eos-id",
+        type=int,
+        default=1,
+        help="""
+        Used only when method is attention-decoder.
+        It specifies ID for the EOS token.
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "sample_rate": 16000,
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "num_decoder_layers": 6,
+            # parameters for decoding
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+        }
+    )
+    return params
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+    if args.method != "attention-decoder":
+        # to save memory as the attention decoder
+        # will not be used
+        params.num_decoder_layers = 0
+
+    params.update(vars(args))
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=params.num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    # Note: We don't use key padding mask for attention during decoding
+    with torch.no_grad():
+        nnet_output, memory, memory_key_padding_mask = model(features)
+
+    batch_size = nnet_output.shape[0]
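+    # Each row of supervision_segments is
+    # (sequence_index, start_frame, num_frames).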
+    supervision_segments = torch.tensor(
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
+    )
+
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = params.num_classes - 1
+
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]:
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "whole-lattice-rescoring",
+            "attention-decoder",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = G.to(device)
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "attention-decoder":
+            logging.info("Use HLG + LM rescoring + attention decoder rescoring")
+            rescored_lattice = rescore_with_whole_lattice(
+                lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            )
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
+                nbest_scale=params.nbest_scale,
+                ngram_lm_scale=params.ngram_lm_scale,
+                attention_scale=params.attention_decoder_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/subsampling.py b/egs/mgb2/ASR/conformer_ctc/subsampling.py
new file mode 120000
index 000000000..16354dc73
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/subsampling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py
new file mode 120000
index 000000000..04b959ecf
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_label_smoothing.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_subsampling.py b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py
new file mode 120000
index 000000000..98c3be3e6
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_subsampling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/test_transformer.py b/egs/mgb2/ASR/conformer_ctc/test_transformer.py
new file mode 120000
index 000000000..8b0990ec6
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/test_transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/test_transformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/conformer_ctc/train.py b/egs/mgb2/ASR/conformer_ctc/train.py
new file mode 100755
index 000000000..08ffee210
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/train.py
@@ -0,0 +1,766 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=50,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--att-rate",
+        type=float,
+        default=0.8,
+        help="""The attention rate.
+        The total loss is (1 -  att_rate) * ctc_loss + att_rate * att_loss
+        """,
+    )
+
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layer of transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model)
+        """,
+    )
+
+    parser.add_argument(
+        "--lr-factor",
+        type=float,
+        default=5.0,
+        help="The lr_factor for Noam optimizer",
+    )
+
+    return parser
+
+
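+# An illustrative training command (a sketch, not a verified recipe
+# invocation; note that prepare.sh in this recipe builds data/lang_bpe_5000
+# by default, while --lang-dir defaults to data/lang_bpe_500):
+#
+#   ./conformer_ctc/train.py \
+#     --world-size 4 \
+#     --num-epochs 50 \
+#     --exp-dir conformer_ctc/exp \
+#     --lang-dir data/lang_bpe_5000
+
+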
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - use_feat_batchnorm: Normalization for the input features. It can be
+                              a boolean indicating whether to apply batch
+                              normalization, or a float used to scale the
+                              input features. If given a float value, the
+                              batchnorm layer in `ConvolutionModule` is
+                              removed as well.
+
+        - attention_dim: Hidden dim of the multi-head attention model.
+
+        - nhead: Number of heads in the multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer
+                              decoder.
+
+        - beam_size: It is used in k2.ctc_loss
+
+        - reduction: It is used in k2.ctc_loss
+
+        - use_double_scores: It is used in k2.ctc_loss
+
+        - weight_decay:  The weight_decay for the optimizer.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "use_feat_batchnorm": True,
+            "attention_dim": 512,
+            "nhead": 8,
+            "num_decoder_layers": 6,
+            # parameters for loss
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            # parameters for Noam
+            "weight_decay": 1e-6,
+            "warm_step": 80000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The learning rate scheduler we are using.
+    Returns:
+      Return the contents of the loaded checkpoint (a dict) if a checkpoint
+      was loaded; otherwise return None.
+    """
+    if params.start_epoch <= 0:
+        return
+
+    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    batch: dict,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    """
+    device = graph_compiler.device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is (N, T, C)
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    token_ids = graph_compiler.texts_to_ids(texts)
+
+    decoding_graph = graph_compiler.compile(token_ids)
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction="none",
+        use_double_scores=params.use_double_scores,
+    )
+    # filter inf from ctc_loss
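+    # (an inf loss can occur when an utterance cannot be aligned to its
+    # decoding graph; such utterances are zeroed out before summing)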
+    ctc_loss = torch.sum(
+        torch.where(
+            ctc_loss != float("inf"),
+            ctc_loss,
+            torch.tensor(0, dtype=torch.float32).to(device),
+        )
+    )
+
+    if params.att_rate != 0.0:
+        with torch.set_grad_enabled(is_training):
+            mmodel = model.module if hasattr(model, "module") else model
+            # Note: We need to generate an unsorted version of token_ids
+            # `encode_supervisions()` called above sorts text, but
+            # encoder_memory and memory_mask are not sorted, so we
+            # use an unsorted version `supervisions["text"]` to regenerate
+            # the token_ids
+            #
+            # See https://github.com/k2-fsa/icefall/issues/97
+            # for more details
+            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=unsorted_token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
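+        # Interpolate the two objectives; with the default att_rate of 0.8,
+        # the total loss is 0.2 * ctc_loss + 0.8 * att_loss.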
+        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+    else:
+        loss = ctc_loss
+        att_loss = torch.tensor([0])
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate != 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            batch=batch,
+            graph_compiler=graph_compiler,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      graph_compiler:
+        It is used to convert transcripts to FSAs.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+    """
+
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
+            params.batch_idx_train += 1
+            batch_size = len(batch["supervisions"]["text"])
+
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            # summary stats
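+            # tot_loss is an exponentially decayed running sum: with
+            # reset_interval = 200, each batch's contribution is scaled by
+            # 199/200 on every later batch.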
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+            # if tot_loss is None:
+            #     logging.warning("Batch mismatch. Skipping ...")
+            #     del batch
+            #     del tot_loss
+            #     continue;
+            # elif tot_loss.isinf() or tot_loss.isnan():
+            #     logging.warning("NaN or Inf loss. Skipping ...")
+            #     del batch
+            #     del tot_loss
+            #     continue;
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+
+            optimizer.zero_grad()
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+
+            if batch_idx % params.log_interval == 0:
+                logging.info(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}, loss[{loss_info}], "
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                )
+
+            if batch_idx % params.log_interval == 0:
+
+                if tb_writer is not None:
+                    loss_info.write_summary(
+                        tb_writer, "train/current_", params.batch_idx_train
+                    )
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
+
+            if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+                logging.info("Computing validation loss")
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+                model.train()
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, "train/valid_", params.batch_idx_train
+                    )
+        else:
+            logging.warning(
+                f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..."
+            )
+            continue
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(42)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=False,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])
+
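+    # Noam learning-rate schedule (as in "Attention Is All You Need"): the
+    # rate grows roughly linearly for the first `warm_step` steps and then
+    # decays with the inverse square root of the step count; `lr_factor`
+    # scales the whole curve.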
+    optimizer = Noam(
+        model.parameters(),
+        model_size=params.attention_dim,
+        factor=params.lr_factor,
+        warm_step=params.warm_step,
+        weight_decay=params.weight_decay,
+    )
+
+    if checkpoints:
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    train_cuts = MGB2.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with a duration between 0.5 and 30 seconds
+        #
+        # Caution: There is a reason to select 30.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 0.5 <= c.duration <= 30.0
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_dl = MGB2.train_dataloaders(train_cuts)
+
+    valid_cuts = MGB2.dev_cuts()
+    valid_dl = MGB2.test_dataloaders(valid_cuts)
+
+    scan_pessimistic_batches_for_oom(
+        model=model,
+        train_dl=train_dl,
+        optimizer=optimizer,
+        graph_compiler=graph_compiler,
+        params=params,
+    )
+
+    for epoch in range(params.start_epoch, params.num_epochs):
+        train_dl.sampler.set_epoch(epoch)
+
+        cur_lr = optimizer._rate
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        if rank == 0:
+            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            tb_writer=tb_writer,
+            world_size=world_size,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: nn.Module,
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            optimizer.zero_grad()
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                batch=batch,
+                graph_compiler=graph_compiler,
+                is_training=True,
+            )
+            loss.backward()
+            clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            raise
+
+
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/conformer_ctc/transformer.py b/egs/mgb2/ASR/conformer_ctc/transformer.py
new file mode 120000
index 000000000..1c3f43fcf
--- /dev/null
+++ b/egs/mgb2/ASR/conformer_ctc/transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/transformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/__init__.py b/egs/mgb2/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/local/compile_hlg.py b/egs/mgb2/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/mgb2/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/compute_fbank_mgb2.py b/egs/mgb2/ASR/local/compute_fbank_mgb2.py
new file mode 100755
index 000000000..6cae69e41
--- /dev/null
+++ b/egs/mgb2/ASR/local/compute_fbank_mgb2.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the MGB2 dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_mgb2():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    dataset_parts = (
+        "train",
+        "test",
+        "dev",
+    )
+    manifests = read_manifests_if_cached(
+        prefix="mgb2", dataset_parts=dataset_parts, output_dir=src_dir
+    )
+    assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        for partition, m in manifests.items():
+            if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
+                logging.info(f"{partition} already exists - skipping.")
+                continue
+            logging.info(f"Processing {partition}")
+            cut_set = CutSet.from_manifests(
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
+            )
+            if "train" in partition:
+                cut_set = (
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                )
+            cut_set = cut_set.compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/feats_{partition}",
+                # when an executor is specified, make more partitions
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+            logging.info("About to split cuts into smaller chunks.")
+            cut_set = cut_set.trim_to_supervisions(
+                keep_overlapping=False, min_duration=None
+            )
+            cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    compute_fbank_mgb2()
diff --git a/egs/mgb2/ASR/local/compute_fbank_musan.py b/egs/mgb2/ASR/local/compute_fbank_musan.py
new file mode 100755
index 000000000..5d0d69a13
--- /dev/null
+++ b/egs/mgb2/ASR/local/compute_fbank_musan.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the musan dataset.
+It looks for manifests in the directory data/manifests.
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+    ChunkedLilcomHdf5Writer,
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    combine,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_musan():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    dataset_parts = (
+        "music",
+        "speech",
+        "noise",
+    )
+    prefix = "musan"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        prefix=prefix,
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        suffix=suffix,
+    )
+    assert manifests is not None
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+    )
+
+    musan_cuts_path = output_dir / "cuts_musan.jsonl.gz"
+
+    if musan_cuts_path.is_file():
+        logging.info(f"{musan_cuts_path} already exists - skipping")
+        return
+
+    logging.info("Extracting features for Musan")
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        # create chunks of Musan with duration 5 - 10 seconds
+        musan_cuts = (
+            CutSet.from_manifests(
+                recordings=combine(part["recordings"] for part in manifests.values())
+            )
+            .cut_into_windows(10.0)
+            .filter(lambda c: c.duration > 5)
+            .compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/feats_musan",
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+        )
+        musan_cuts.to_file(musan_cuts_path)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    compute_fbank_musan()
diff --git a/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py
new file mode 100755
index 000000000..a8d5117c9
--- /dev/null
+++ b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+"""
+Convert a transcript file containing words to a corpus file containing tokens
+for LM training with the help of a lexicon.
+
+If the lexicon contains phones, the resulting LM will be a phone LM; If the
+lexicon contains word pieces, the resulting LM will be a word piece LM.
+
+If a word has multiple pronunciations, the one that appears first in the lexicon
+is kept; others are removed.
+
+If the input transcript is:
+
+    hello zoo world hello
+    world zoo
+    foo zoo world hellO
+
+and if the lexicon is
+
+    <UNK> SPN
+    hello h e l l o 2
+    hello h e l l o
+    world w o r l d
+    zoo z o o
+
+Then the output is
+
+    h e l l o 2 z o o w o r l d h e l l o 2
+    w o r l d z o o
+    SPN z o o w o r l d SPN
+"""
+
+import argparse
+from pathlib import Path
+from typing import Dict, List
+
+from generate_unique_lexicon import filter_multiple_pronunications
+
+from icefall.lexicon import read_lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="The input transcript file."
+        "We assume that the transcript file consists of "
+        "lines. Each line consists of space separated words.",
+    )
+    parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
+    parser.add_argument("--oov", type=str, default="", help="The OOV word.")
+
+    return parser.parse_args()
+
+
+def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None:
+    """
+    Args:
+      lexicon:
+        A dict containing pronunciations. Its keys are words and values
+        are pronunciations (i.e., tokens).
+      line:
+        A line of transcript consisting of space-separated words.
+      oov_token:
+        The pronunciation of the oov word if a word in `line` is not present
+        in the lexicon.
+    Returns:
+      Return None.
+    """
+    s = ""
+    words = line.strip().split()
+    for i, w in enumerate(words):
+        tokens = lexicon.get(w, oov_token)
+        s += " ".join(tokens)
+        s += " "
+    print(s.strip())
+
+
+def main():
+    args = get_args()
+    assert Path(args.lexicon).is_file()
+    assert Path(args.transcript).is_file()
+    assert len(args.oov) > 0
+
+    # Only the first pronunciation of a word is kept
+    lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
+
+    lexicon = dict(lexicon)
+
+    assert args.oov in lexicon
+
+    oov_token = lexicon[args.oov]
+
+    with open(args.transcript) as f:
+        for line in f:
+            process_line(lexicon=lexicon, line=line, oov_token=oov_token)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/local/display_manifest_statistics.py b/egs/mgb2/ASR/local/display_manifest_statistics.py
new file mode 100755
index 000000000..d3e224905
--- /dev/null
+++ b/egs/mgb2/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+See the function `remove_short_and_long_utt()` in transducer/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest
+
+
+def main():
+    # path = "./data/fbank/cuts_train.jsonl.gz"
+    path = "./data/fbank/cuts_dev.jsonl.gz"
+    # path = "./data/fbank/cuts_test.jsonl.gz"
+
+    cuts = load_manifest(path)
+    cuts.describe()
+
+
+if __name__ == "__main__":
+    main()
+
+"""
+# train
+
+Cuts count: 1125309
+Total duration (hours): 3403.9
+Speech duration (hours): 3403.9 (100.0%)
+***
+Duration statistics (seconds):
+mean    10.9
+std     10.1
+min     0.2
+25%     5.2
+50%     7.8
+75%     12.7
+99%     52.0
+99.5%   65.1
+99.9%   99.5
+max     228.9
+
+
+# test
+Cuts count: 5365
+Total duration (hours): 9.6
+Speech duration (hours): 9.6 (100.0%)
+***
+Duration statistics (seconds):
+mean    6.4
+std     1.5
+min     1.6
+25%     5.3
+50%     6.5
+75%     7.6
+99%     9.5
+99.5%   9.7
+99.9%   10.3
+max     12.4
+
+# dev
+Cuts count: 5002
+Total duration (hours): 8.5
+Speech duration (hours): 8.5 (100.0%)
+***
+Duration statistics (seconds):
+mean    6.1
+std     1.7
+min     1.5
+25%     4.8
+50%     6.2
+75%     7.4
+99%     9.5
+99.5%   9.7
+99.9%   10.1
+max     20.3
+
+"""
diff --git a/egs/mgb2/ASR/local/generate_unique_lexicon.py b/egs/mgb2/ASR/local/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/mgb2/ASR/local/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh
new file mode 100755
index 000000000..3b673db6f
--- /dev/null
+++ b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+# Copyright 2022 QCRI (author: Amir Hussein)
+# Apache 2.0
+# This script prepares the graphemic lexicon.
+
+dir=data/local/dict
+lexicon_url1="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_grapheme_lexicon_20160209.bz2";
+lexicon_url2="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_phoneme_lexicon_20140317.bz2";
+stage=0
+lang_dir=download/lm
+mkdir -p $lang_dir
+
+if [ $stage -le 0 ]; then
+  echo "$0: Downloading text for lexicon... $(date)."
+  wget --no-check-certificate -P $lang_dir $lexicon_url1
+  wget --no-check-certificate -P $lang_dir $lexicon_url2
+  bzcat $lang_dir/ar-ar_grapheme_lexicon_20160209.bz2 | sed '1,3d' | awk '{print $1}'  >  $lang_dir/grapheme_lexicon
+  bzcat $lang_dir/ar-ar_phoneme_lexicon_20140317.bz2 | sed '1,3d' | awk '{print $1}' >>  $lang_dir/phoneme_lexicon
+  cat download/lm/train/text | cut -d ' ' -f 2- | tr -s " " "\n" | sort -u >> $lang_dir/uniq_words
+fi
+
+
+if [ $stage -le 0 ]; then
+  echo "$0: processing lexicon text and creating lexicon... $(date)."
+  # remove vowels and the rare alef wasla
+  cat $lang_dir/uniq_words |  sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir/grapheme_lexicon.txt
+fi
+
+echo "$0: Lexicon preparation succeeded"
diff --git a/egs/tedlium3/ASR/local/prepare_lang.py b/egs/mgb2/ASR/local/prepare_lang.py
similarity index 100%
rename from egs/tedlium3/ASR/local/prepare_lang.py
rename to egs/mgb2/ASR/local/prepare_lang.py
diff --git a/egs/mgb2/ASR/local/prepare_lang_bpe.py b/egs/mgb2/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/mgb2/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py
new file mode 100755
index 000000000..99e1fa34d
--- /dev/null
+++ b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+
+# Copyright      2022  Amir Hussein
+# Apache 2.0
+
+# Given a column of words, this script prepares a graphemic lexicon.
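+# For example, an input word "ab*c" produces the lexicon line
+# "ab*c  a b V c" ("*" is mapped to the grapheme "V").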
+
+import argparse
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description="""Creates the list of characters and words in lexicon"""
+    )
+    parser.add_argument("input", type=str, help="""Input list of words file""")
+    parser.add_argument("output", type=str, help="""output graphemic lexicon""")
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    lex = {}
+    args = get_args()
+    with open(args.input, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            characters = list(line)
+            characters = " ".join(["V" if char == "*" else char for char in characters])
+            lex[line] = characters
+
+    with open(args.output, "w", encoding="utf-8") as fp:
+        for key in sorted(lex):
+            fp.write(key + "  " + lex[key] + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/tedlium3/ASR/local/test_prepare_lang.py b/egs/mgb2/ASR/local/test_prepare_lang.py
similarity index 100%
rename from egs/tedlium3/ASR/local/test_prepare_lang.py
rename to egs/mgb2/ASR/local/test_prepare_lang.py
diff --git a/egs/mgb2/ASR/prepare.sh b/egs/mgb2/ASR/prepare.sh
new file mode 100755
index 000000000..899d15d97
--- /dev/null
+++ b/egs/mgb2/ASR/prepare.sh
@@ -0,0 +1,234 @@
+#!/usr/bin/env bash
+# Copyright 2022 Johns Hopkins University  (Amir Hussein)
+# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+set -eou pipefail
+nj=30
+stage=7
+stop_stage=1000
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. 
+#
+#  - $dl_dir/mgb2
+#      You can obtain the MGB2 data from https://arabicspeech.org/mgb2
+#      (see the note below; a registration form is required).
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+#
+# Note: MGB2 is not available for direct download. However, you can fill out
+# the form at https://arabicspeech.org/mgb2 and download it from there.
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  5000
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/MGB2,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/mgb2 $dl_dir/MGB2
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare mgb2 manifest"
+  # We assume that you have downloaded the mgb2 corpus
+  # to $dl_dir/mgb2
+  mkdir -p data/manifests
+
+  lhotse prepare mgb2 $dl_dir/mgb2 data/manifests
+  
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  mkdir -p data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for mgb2"
+  mkdir -p data/fbank
+  ./local/compute_fbank_mgb2.py
+  # shuffle the training data
+  gunzip -c data/fbank/cuts_train.jsonl.gz | shuf | gzip -c > data/fbank/cuts_train_shuf.jsonl.gz
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p data/fbank
+  ./local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  if [[ ! -e download/lm/train/text ]]; then
+  # export the training text, which is used to build the grapheme lexicon
+  lhotse kaldi export \
+    data/manifests/mgb2_recordings_train.jsonl.gz \
+    data/manifests/mgb2_supervisions_train.jsonl.gz  \
+    download/lm/train
+  fi
+
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+  ./local/prep_mgb2_lexicon.sh 
+  python local/prepare_mgb2_lexicon.py  $dl_dir/lm/grapheme_lexicon.txt  $dl_dir/lm/lexicon.txt
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/lm/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data for BPE training"
+      files=$(
+        find "$dl_dir/lm/train" -name "text"
+      )
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2- | sed -r '/^\s*$/d'
+      done > $lang_dir/transcript_words.txt
+    fi
+
+    ./local/train_bpe_model.py \
+      --lang-dir $lang_dir \
+      --vocab-size $vocab_size \
+      --transcript $lang_dir/transcript_words.txt
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare bigram P"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
+        --oov "" \
+        > $lang_dir/transcript_tokens.txt
+    fi
+
+    if [ ! -f $lang_dir/P.arpa ]; then
+      ./shared/make_kn_lm.py \
+        -ngram-order 2 \
+        -text $lang_dir/transcript_tokens.txt \
+        -lm $lang_dir/P.arpa
+    fi
+
+    if [ ! -f $lang_dir/P.fst.txt ]; then
+      python3 -m kaldilm \
+        --read-symbol-table="$lang_dir/tokens.txt" \
+        --disambig-symbol='#0' \
+        --max-order=2 \
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
+  # We assume you have installed kaldilm. If not, please install
+  # it using: pip install kaldilm
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p data/lm
+    if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+      # It is used in building HLG
+      ./shared/make_kn_lm.py \
+          -ngram-order 3 \
+          -text $lang_dir/transcript_words.txt \
+          -lm $lang_dir/G.arpa
+
+      python3 -m kaldilm \
+        --read-symbol-table="data/lang_phone/words.txt" \
+        --disambig-symbol='#0' \
+        --max-order=3 \
+        $lang_dir/G.arpa > data/lm/G_3_gram.fst.txt
+    fi
+
+    if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+      # It is used for LM rescoring
+      ./shared/make_kn_lm.py \
+          -ngram-order 4 \
+          -text $lang_dir/transcript_words.txt \
+          -lm $lang_dir/4-gram.arpa
+
+      python3 -m kaldilm \
+        --read-symbol-table="data/lang_phone/words.txt" \
+        --disambig-symbol='#0' \
+        --max-order=4 \
+        $lang_dir/4-gram.arpa > data/lm/G_4_gram.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile HLG"
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
+fi
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py b/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py
new file mode 120000
index 000000000..a73848de9
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -0,0 +1 @@
+../conformer_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py
new file mode 120000
index 000000000..02d01b343
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/beam_search.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
new file mode 100755
index 000000000..72338bade
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
@@ -0,0 +1,619 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins        (authors: Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method beam_search \
+    --beam-size 10
+
+(3) modified beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 10
+
+(4) fast beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 18 \
+    --avg 5 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 200 \
+    --decoding-method fast_beam_search \
+    --beam-size 10 \
+    --max-contexts 4 \
+    --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_2000/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search".
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7".
+        - value: It contains the decoding result. `len(value)` equals the
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
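+    # encoder_out: (N, T', encoder_dim) after subsampling; encoder_out_lens
+    # gives the number of valid frames for each utterance.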
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for hyp_words, ref_text in zip(hyps, texts):
+
+                ref_words = ref_text.split()
+                this_batch.append((ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
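+    # Either average the raw parameters of the selected checkpoints, or
+    # (with --use-averaged-model) recover the running average maintained
+    # during training from the `model_avg` stored in two endpoint checkpoints.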
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
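+    # fast_beam_search needs a decoding graph; a trivial graph over the
+    # non-blank tokens accepts any token sequence (no external LM is used).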
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    test_cuts = MGB2.test_cuts()
+    dev_cuts = MGB2.dev_cuts()
+
+    test_dl = MGB2.test_dataloaders(test_cuts)
+    dev_dl = MGB2.test_dataloaders(dev_cuts)
+
+    test_sets = ["test", "dev"]
+    test_all_dl = [test_dl, dev_dl]
+
+    for test_set, test_dl in zip(test_sets, test_all_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py
new file mode 120000
index 000000000..6775ee67e
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/decoder.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py
new file mode 120000
index 000000000..972e44ca4
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/encoder_interface.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/export.py b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py
new file mode 100755
index 000000000..7a5d7f680
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless5/export.py \
+  --exp-dir ./pruned_transducer_stateless5/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/mgb2/ASR
+    ./pruned_transducer_stateless5/decode.py \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    assert args.jit is False, "Support for torchscript will be added later"
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+
+    model.to("cpu")
+    model.eval()
+
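+    # Note: --jit is rejected by the assert above, so only the state_dict
+    # branch below is currently reachable.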
+    if params.jit:
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py
new file mode 120000
index 000000000..f5279e151
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/joiner.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py
new file mode 120000
index 000000000..7b417fd89
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/model.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py
new file mode 120000
index 000000000..210374f22
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/optim.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
new file mode 100755
index 000000000..77ba0873b
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
+./pruned_transducer_stateless5/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
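+    # Pad all utterances to the longest one; log(1e-10) acts as "silence"
+    # padding in the log-mel feature space.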
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py
new file mode 120000
index 000000000..ff7bfeda9
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py
new file mode 120000
index 000000000..b71d7bb81
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/test_model.py
\ No newline at end of file
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py
new file mode 100755
index 000000000..e1b623353
--- /dev/null
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py
@@ -0,0 +1,1176 @@
+#!/usr/bin/env python3
+# Copyright    2022  Johns Hopkins        (authors: Amir Hussein)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 2 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 200 \
+  --num-buckets 50
+
+# For mixed precision training:
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 2 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 200 \
+  --num-buckets 50
+
+"""
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import nvidia_smi
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import MGB2AsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Feedforward dimension of the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Attention dimension in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_2000/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need " "to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=10,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=True,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for Noam
+            "model_warm_step": 80000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint of
+    epoch `params.start_epoch - 1`, i.e.,
+    `params.exp_dir/epoch-{params.start_epoch-1}.pt`.
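+
+    For example, `params.start_batch = 30000` resumes from
+    `exp_dir/checkpoint-30000.pt`, while `params.start_epoch = 3` resumes
+    from `exp_dir/epoch-2.pt` (the values are only illustrative).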
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      scheduler:
+        The learning rate scheduler used in the training.
+      sampler:
+        The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+    reduction="none",
+) -> Tuple[Tensor, MetricsTracker, bool]:
+    """
+    Compute the transducer loss (simple + pruned) given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      warmup:
+        A floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    Returns:
+      Return a tuple of (loss, metric_info, inf_flag). `inf_flag` is True if
+      any of the losses in this batch was not finite.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+            reduction="none",
+        )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        inf_flag = False
+        if not torch.all(is_finite):
+            inf_flag = True
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
+
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
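+        # Concretely: with warmup = batch_idx_train / model_warm_step, the
+        # scale below is 0.0 during the first model_warm_step batches,
+        # 0.1 for roughly the next model_warm_step batches, and 1.0 after
+        # that.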
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    # info["utterances"] = feature.size(0)
+    # # averaged input duration in frames over utterances
+    # info["utt_duration"] = feature_lens.sum().item()
+    # # averaged padding proportion over utterances
+    # info["utt_pad_proportion"] = (
+    #     ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    # )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info, inf_flag
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(valid_dl):
+            loss, loss_info, inf_flag = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=False,
+            )
+            assert loss.requires_grad is False
+            tot_loss = tot_loss + loss_info
+
+        if world_size > 1:
+            tot_loss.reduce(loss.device)
+
+        loss_value = tot_loss["loss"] / tot_loss["frames"]
+        if loss_value < params.best_valid_loss:
+            params.best_valid_epoch = params.cur_epoch
+            params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler; step_batch() is called after every batch.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
+
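+        # Guard against malformed batches: the number of feature examples
+        # must equal the number of supervision texts, otherwise the batch is
+        # skipped (see the warning branch at the end of the loop).  When
+        # resuming from a checkpoint saved mid-epoch, batches with
+        # batch_idx < cur_batch_idx have already been seen and are skipped
+        # too.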
+        if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
+            if batch_idx < cur_batch_idx:
+                continue
+            cur_batch_idx = batch_idx
+
+            params.batch_idx_train += 1
+            batch_size = len(batch["supervisions"]["text"])
+
+            try:
+                with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                    loss, loss_info, inf_flag = compute_loss(
+                        params=params,
+                        model=model,
+                        sp=sp,
+                        batch=batch,
+                        is_training=True,
+                        warmup=(params.batch_idx_train / params.model_warm_step),
+                    )
+                # summary stats
+                tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+                # NOTE: We use reduction==sum and loss is computed over utterances
+                # in the batch and there is no normalization to it so far.
+                if not inf_flag:
+                    scaler.scale(loss).backward()
+                    scheduler.step_batch(params.batch_idx_train)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    optimizer.zero_grad()
+                else:
+                    continue
+            except:  # noqa
+                display_and_save_batch(batch, params=params, sp=sp)
+                raise
+
+            if params.print_diagnostics and batch_idx == 5:
+                return
+
+            if (
+                rank == 0
+                and params.batch_idx_train > 0
+                and params.batch_idx_train % params.average_period == 0
+            ):
+                update_averaged_model(
+                    params=params,
+                    model_cur=model,
+                    model_avg=model_avg,
+                )
+
+            if (
+                params.batch_idx_train > 0
+                and params.batch_idx_train % params.save_every_n == 0
+            ):
+                params.cur_batch_idx = batch_idx
+                save_checkpoint_with_global_batch_idx(
+                    out_dir=params.exp_dir,
+                    global_batch_idx=params.batch_idx_train,
+                    model=model,
+                    model_avg=model_avg,
+                    params=params,
+                    optimizer=optimizer,
+                    scheduler=scheduler,
+                    sampler=train_dl.sampler,
+                    scaler=scaler,
+                    rank=rank,
+                )
+                del params.cur_batch_idx
+                remove_checkpoints(
+                    out_dir=params.exp_dir,
+                    topk=params.keep_last_k,
+                    rank=rank,
+                )
+
+            if batch_idx % params.log_interval == 0:
+                cur_lr = scheduler.get_last_lr()[0]
+                # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea
+                memory_debugging()
+                logging.info(
+                    f"Epoch {params.cur_epoch}, "
+                    f"batch {batch_idx}, loss[{loss_info}], "
+                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                    f"lr: {cur_lr:.2e}"
+                )
+
+                if tb_writer is not None:
+                    tb_writer.add_scalar(
+                        "train/learning_rate", cur_lr, params.batch_idx_train
+                    )
+
+                    loss_info.write_summary(
+                        tb_writer, "train/current_", params.batch_idx_train
+                    )
+                    tot_loss.write_summary(
+                        tb_writer, "train/tot_", params.batch_idx_train
+                    )
+
+            if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+                logging.info("Computing validation loss")
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+                model.train()
+                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+                if tb_writer is not None:
+                    valid_info.write_summary(
+                        tb_writer, "train/valid_", params.batch_idx_train
+                    )
+        else:
+            logging.warning(
+                f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..."
+            )
+            continue
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def memory_debugging():
+    # memory nvidia debugging
+    nvidia_smi.nvmlInit()
+
+    deviceCount = nvidia_smi.nvmlDeviceGetCount()
+    for i in range(deviceCount):
+        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
+        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+        logging.info(
+            "Device {}: {}, Memory : ({:.2f}% free): {}(total), {} (free), {} (used)".format(
+                i,
+                nvidia_smi.nvmlDeviceGetName(handle),
+                100 * info.free / info.total,
+                info.total,
+                info.free,
+                info.used,
+            )
+        )
+
+    nvidia_smi.nvmlShutdown()
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    MGB2 = MGB2AsrDataModule(args)
+    train_cuts = MGB2.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 0.5 and 30 seconds
+        #
+        # Caution: There is a reason to select 30.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 0.5 <= c.duration <= 30.0
+
+    def remove_short_and_long_text(c: Cut):
+        # Keep only utterances whose transcript has 20 to 450 characters.
+
+        return 20 <= len(c.supervisions[0].text) <= 450
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.filter(remove_short_and_long_text)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict)
+
+    valid_cuts = MGB2.dev_cuts()
+    valid_dl = MGB2.test_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+
+                loss, _, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=0.0,
+                )
+            loss.backward()
+            # clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    MGB2AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
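+# Use a single CPU thread per process for PyTorch ops; parallelism comes from
+# DataLoader workers and DDP processes, so this helps avoid CPU
+# oversubscription (the same convention is used in other icefall recipes).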
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/mgb2/ASR/shared b/egs/mgb2/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/mgb2/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py
index af54dbd07..bed3856e4 100755
--- a/egs/ptb/LM/local/sort_lm_training_data.py
+++ b/egs/ptb/LM/local/sort_lm_training_data.py
@@ -135,9 +135,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py
index 877720e7b..3790045fa 100755
--- a/egs/ptb/LM/local/test_prepare_lm_training_data.py
+++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py
@@ -54,9 +54,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh
index 70586785d..69fab999a 100755
--- a/egs/ptb/LM/prepare.sh
+++ b/egs/ptb/LM/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
@@ -19,9 +22,9 @@ dl_dir=$PWD/download
 # if the array contains xxx, yyy
 vocab_sizes=(
   500
-  1000
-  2000
-  5000
+  # 1000
+  # 2000
+  # 5000
 )
 
 # All files generated by this script are saved in "data".
@@ -39,11 +42,14 @@ log "dl_dir: $dl_dir"
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "Stage -1: Download data"
+
+  # Caution: The downloaded data has already been normalized for LM training.
+
   if [ ! -f $dl_dir/.complete ]; then
-    url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt
-    wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt
+    url=http://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data
+    wget --directory-prefix $dl_dir $url/ptb.train.txt
+    wget --directory-prefix $dl_dir $url/ptb.valid.txt
+    wget --directory-prefix $dl_dir $url/ptb.test.txt
     touch $dl_dir/.complete
   fi
 fi
@@ -51,11 +57,15 @@ fi
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Train BPE model"
 
+  # Caution: You have to use the same bpe model for training your acoustic model
+  # Caution: You have to use the same bpe model for training your acoustic model
+  # Caution: You have to use the same bpe model for training your acoustic model
+
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
-    mkdir -p $out_dir
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
     ./local/train_bpe_model.py \
-      --out-dir $out_dir \
+      --lang-dir $lang_dir \
       --vocab-size $vocab_size \
       --transcript $dl_dir/ptb.train.txt
   done
@@ -66,20 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # Note: ptb.train.txt has already been normalized
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
+    lang_dir=data/lang_bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
     mkdir -p $out_dir
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.train.txt \
       --lm-archive $out_dir/lm_data.pt
 
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.valid.txt \
       --lm-archive $out_dir/lm_data-valid.pt
 
     ./local/prepare_lm_training_data.py \
-      --bpe-model $out_dir/bpe.model \
+      --bpe-model $lang_dir/bpe.model \
       --lm-data $dl_dir/ptb.test.txt \
       --lm-archive $out_dir/lm_data-test.pt
   done
@@ -95,7 +106,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # in a sentence.
 
   for vocab_size in ${vocab_sizes[@]}; do
-    out_dir=data/bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
     mkdir -p $out_dir
     ./local/sort_lm_training_data.py \
       --in-lm-data $out_dir/lm_data.pt \
diff --git a/egs/ptb/LM/rnn_lm b/egs/ptb/LM/rnn_lm
new file mode 120000
index 000000000..87f29771e
--- /dev/null
+++ b/egs/ptb/LM/rnn_lm
@@ -0,0 +1 @@
+../../../icefall/rnn_lm
\ No newline at end of file
diff --git a/egs/ptb/LM/train-rnn-lm.sh b/egs/ptb/LM/train-rnn-lm.sh
new file mode 100755
index 000000000..29c609ee1
--- /dev/null
+++ b/egs/ptb/LM/train-rnn-lm.sh
@@ -0,0 +1,67 @@
+#!/usr/bin/env bash
+
+# Please run ./prepare.sh first
+
+stage=-1
+stop_stage=100
+
+# Number of GPUs to use for training
+world_size=1
+
+# Number of epochs to train
+num_epochs=20
+
+# Use this epoch for computing ppl
+use_epoch=19
+
+# number of models to average for computing ppl
+use_avg=2
+
+exp_dir=./my-rnnlm-exp
+
+. shared/parse_options.sh || exit 1
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Training RNN LM"
+
+  ./rnn_lm/train.py \
+    --exp-dir $exp_dir \
+    --start-epoch 0 \
+    --num-epochs $num_epochs \
+    --world-size $world_size \
+    --use-fp16 0 \
+    --vocab-size 500 \
+    \
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data.pt \
+    --lm-data-valid ./data/lm_training_bpe_500/sorted_lm_data-valid.pt \
+    \
+    --embedding-dim 800 \
+    --hidden-dim 200 \
+    --num-layers 2 \
+    --tie-weights false \
+    --batch-size 50
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Computing perplexity"
+
+  ./rnn_lm/compute_perplexity.py \
+    --exp-dir $exp_dir \
+    --epoch $use_epoch \
+    --avg $use_avg \
+    --vocab-size 500 \
+    \
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt \
+    \
+    --embedding-dim 800 \
+    --hidden-dim 200 \
+    --num-layers 2 \
+    --tie-weights false \
+    --batch-size 50
+fi
diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py
index 6cb8b65ae..9bea28a41 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py
@@ -87,9 +87,7 @@ def compute_fbank_musan():
     # create chunks of Musan with duration 5 - 10 seconds
     musan_cuts = (
         CutSet.from_manifests(
-            recordings=combine(
-                part["recordings"] for part in manifests.values()
-            )
+            recordings=combine(part["recordings"] for part in manifests.values())
         )
         .cut_into_windows(10.0)
         .filter(lambda c: c.duration > 5)
@@ -108,8 +106,6 @@ def compute_fbank_musan():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
index 8116e7605..20ff6d7ab 100755
--- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
+++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py
@@ -103,11 +103,7 @@ def compute_fbank_spgispeech(args):
             chunk_size=chunk_size,
         )
         start = args.start
-        stop = (
-            min(args.stop, args.num_splits)
-            if args.stop > 0
-            else args.num_splits
-        )
+        stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits
         num_digits = len(str(args.num_splits))
         for i in range(start, stop):
             idx = f"{i + 1}".zfill(num_digits)
@@ -129,9 +125,7 @@ def compute_fbank_spgispeech(args):
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
-            cut_set = load_manifest_lazy(
-                src_dir / f"cuts_{partition}_raw.jsonl.gz"
-            )
+            cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz")
             cut_set = cut_set.compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=output_dir / f"feats_{partition}",
@@ -144,9 +138,7 @@ def compute_fbank_spgispeech(args):
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py
index 8c8f1c133..508d4acd8 100755
--- a/egs/spgispeech/ASR/local/prepare_splits.py
+++ b/egs/spgispeech/ASR/local/prepare_splits.py
@@ -55,9 +55,7 @@ def split_spgispeech_train():
 
     # Add speed perturbation
     train_cuts = (
-        train_cuts
-        + train_cuts.perturb_speed(0.9)
-        + train_cuts.perturb_speed(1.1)
+        train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1)
     )
 
     # Write the manifests to disk.
@@ -73,9 +71,7 @@ def split_spgispeech_train():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     split_spgispeech_train()
diff --git a/egs/spgispeech/ASR/prepare.sh b/egs/spgispeech/ASR/prepare.sh
index 231ebd742..8331f94d5 100755
--- a/egs/spgispeech/ASR/prepare.sh
+++ b/egs/spgispeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=20
@@ -105,7 +108,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     pieces=$(find data/manifests -name "cuts_train_[0-9]*.jsonl.gz")
     lhotse combine $pieces data/manifests/cuts_train.jsonl.gz
   fi
-  gunzip -c data/manifests/train_cuts.jsonl.gz | shuf | gzip -c > data/manifests/train_cuts_shuf.jsonl.gz
+  gunzip -c data/manifests/cuts_train.jsonl.gz | shuf | gzip -c > data/manifests/cuts_train_shuf.jsonl.gz
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -133,7 +136,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     # Add special words to words.txt
     echo " 0" > $lang_dir/words.txt
     echo "!SIL 1" >> $lang_dir/words.txt
-    echo "[UNK] 2" >> $lang_dir/words.txt
+    echo " 2" >> $lang_dir/words.txt
 
     # Add regular words to words.txt
     gunzip -c data/manifests/cuts_train_raw.jsonl.gz \
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index f165f6e60..d94a92503 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -176,17 +176,13 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -208,9 +204,7 @@ class SPGISpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             input_transforms.append(
                 SpecAugment(
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
@@ -227,9 +221,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
             )
         else:
@@ -282,9 +274,7 @@ class SPGISpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
             )
         else:
             validate = K2SpeechRecognitionDataset(
@@ -328,9 +318,7 @@ class SPGISpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech train cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_train_shuf.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz")
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index c39bd0530..4434aae62 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -187,8 +183,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -246,9 +241,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -263,10 +256,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -389,9 +379,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -403,18 +391,14 @@ def save_results(
     test_set_wers = dict()
     test_set_cers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        wers_filename = (
-            params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt"
         with open(wers_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -424,12 +408,8 @@ def save_results(
         # we also compute CER for spgispeech dataset.
         results_char = []
         for res in results:
-            results_char.append(
-                (res[0], list("".join(res[1])), list("".join(res[2])))
-            )
-        cers_filename = (
-            params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+        cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt"
         with open(cers_filename, "w") as f:
             cer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -438,32 +418,21 @@ def save_results(
 
         logging.info("Wrote detailed error stats to {}".format(wers_filename))
 
-    test_set_wers = {
-        k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])
-    }
-    test_set_cers = {
-        k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])
-    }
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+    test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
             print(
-                "{}\t{}\t{}".format(
-                    key, test_set_wers[key], test_set_cers[key]
-                ),
+                "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
                 file=f,
             )
 
     s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key in test_set_wers:
-        s += "{}\t{}\t{}{}\n".format(
-            key, test_set_wers[key], test_set_cers[key], note
-        )
+        s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
         note = ""
     logging.info(s)
 
@@ -496,9 +465,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -530,8 +497,7 @@ def main():
         ]
         if len(filenames) == 0:
             raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
             )
         elif len(filenames) < params.avg:
             raise ValueError(
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
index 77faa3c0e..68763808a 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ import sentencepiece as spm
 import torch
 from train import get_params, get_transducer_model
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import str2bool
 
 
@@ -119,8 +115,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -196,9 +191,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
index dda29b3e5..d943180b1 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py
@@ -77,9 +77,7 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -155,8 +153,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need to be "
-        "changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -179,8 +176,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -203,8 +199,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -554,23 +549,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -733,9 +721,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
index 4582609ac..602e50d29 100755
--- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
+++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py
@@ -84,9 +84,7 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80):
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -112,9 +110,7 @@ def get_args():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py
index 2c5b8b8b3..1262baf63 100755
--- a/egs/tal_csasr/ASR/local/prepare_char.py
+++ b/egs/tal_csasr/ASR/local/prepare_char.py
@@ -87,9 +87,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/tal_csasr/ASR/local/prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--lang-dir", type=str, help="The lang dir, data/lang_phone"
-    )
+    parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
     return parser.parse_args()
 
 
diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/tal_csasr/ASR/local/test_prepare_lang.py
+++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
     fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa.draw("L.pdf", title="L")
 
-    fsa_disambig = lexicon_to_fst(
-        lexicon_disambig, phone2id=phone2id, word2id=word2id
-    )
+    fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
     fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
     fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/tal_csasr/ASR/local/text2token.py
+++ b/egs/tal_csasr/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="<space>", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh
index 340521ad8..c5d498d74 100755
--- a/egs/tal_csasr/ASR/prepare.sh
+++ b/egs/tal_csasr/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
@@ -9,9 +12,12 @@ stop_stage=100
 # directories and files. If not, they will be downloaded
 # by this script automatically.
 #
-#  - $dl_dir/tal_csasr
+#  - $dl_dir/TALCS_corpus
 #      You can find three directories: train_set, dev_set, and test_set.
 #      You can get it from https://ai.100tal.com/dataset
+#     - dev_set
+#     - test_set
+#     - train_set
 #
 #  - $dl_dir/musan
 #      This directory contains the following directories downloaded from
@@ -41,7 +47,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Download data"
   # Before you run this script, you must get the TAL_CSASR dataset
   # from https://ai.100tal.com/dataset
-  mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr
+  if [ ! -d $dl_dir/tal_csasr/TALCS_corpus ]; then
+    mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr
+  fi
 
   # If you have pre-downloaded it to /path/to/TALCS_corpus,
   # you can create a symlink
@@ -113,7 +121,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   fi
 
   # Prepare text.
-  # Note: in Linux, you can install jq with the  following command:
+  # Note: in Linux, you can install jq with the following command:
   # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
   # 2. chmod +x ./jq
   # 3. cp jq /usr/bin
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 49bfb148b..2240c1c1d 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -222,17 +222,13 @@ class TAL_CSASRAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -254,9 +250,7 @@ class TAL_CSASRAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -300,9 +294,7 @@ class TAL_CSASRAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -360,9 +352,7 @@ class TAL_CSASRAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index b624913f5..3bfb832fb 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -208,8 +208,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -268,9 +267,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     zh_hyps = []
     en_hyps = []
@@ -303,10 +300,7 @@ def decode_one_batch(
             hyps.append(chars_new)
             zh_hyps.append(zh_text)
             en_hyps.append(en_text)
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -375,9 +369,7 @@ def decode_one_batch(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
             for i in range(encoder_out.size(0)):
-                hyp = sp.decode(
-                    [lexicon.token_table[idx] for idx in hyp_tokens[i]]
-                )
+                hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
                 chars = pattern.split(hyp.upper())
                 chars_new = []
                 zh_text = []
@@ -506,9 +498,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results, zh_results, en_results
 
 
@@ -519,18 +509,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -540,10 +526,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -585,9 +568,7 @@ def main():
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -619,9 +600,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -648,9 +629,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index 8f900208a..bc33dd160 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -139,8 +139,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     add_model_arguments(parser)
@@ -176,9 +175,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -205,9 +204,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -277,9 +276,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index dbe213b24..3305f5bd3 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -197,10 +196,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -263,15 +261,11 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=features, x_lens=feature_lengths
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
 
     num_waves = encoder_out.size(0)
     hyps = []
@@ -367,9 +361,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index ca35eba45..43f3231ba 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -86,9 +86,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -214,8 +212,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -238,8 +235,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -262,8 +258,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -600,11 +595,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -634,22 +625,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -828,9 +812,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -944,7 +926,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md
index 511b19f73..38eaa8f44 100644
--- a/egs/tedlium3/ASR/RESULTS.md
+++ b/egs/tedlium3/ASR/RESULTS.md
@@ -1,5 +1,88 @@
 ## Results
 
+### TedLium3 BPE training results (Conformer-CTC 2)
+
+#### [conformer_ctc2](./conformer_ctc2)
+
+See  for more details.
+
+The tensorboard log can be found at
+
+
+You can find a pretrained model and decoding results at:
+
+
+Number of model parameters: 101141699, i.e., 101.14 M
+
+The WERs are
+
+|                          | dev        | test        | comment             |
+|--------------------------|------------|-------------|---------------------|
+| ctc decoding             | 6.45       | 5.96        | --epoch 38 --avg 26 |
+| 1best                    | 5.92       | 5.51        | --epoch 38 --avg 26 |
+| whole lattice rescoring  | 5.96       | 5.47        | --epoch 38 --avg 26 |
+| attention decoder        | 5.60       | 5.33        | --epoch 38 --avg 26 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc2/train.py \
+    --world-size 4 \
+    --num-epochs 40 \
+    --exp-dir conformer_ctc2/exp \
+    --max-duration 350 \
+    --use-fp16 true
+```
+
+The decoding command is:
+```
+epoch=38
+avg=26
+
+## ctc decoding
+./conformer_ctc2/decode.py \
+  --method ctc-decoding \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## 1best
+./conformer_ctc2/decode.py \
+  --method 1best \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## whole lattice rescoring
+./conformer_ctc2/decode.py \
+  --method whole-lattice-rescoring \
+  --exp-dir conformer_ctc2/exp \
+  --lm-path data/lm/G_4_gram_big.pt \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+
+## attention decoder
+./conformer_ctc2/decode.py \
+  --method attention-decoder \
+  --exp-dir conformer_ctc2/exp \
+  --lang-dir data/lang_bpe_500 \
+  --result-dir conformer_ctc2/exp \
+  --max-duration 500 \
+  --epoch $epoch \
+  --avg $avg
+```
+
 ### TedLium3 BPE training results (Pruned Transducer)
 
 #### 2022-03-21
diff --git a/egs/tedlium3/ASR/conformer_ctc2/__init__.py b/egs/tedlium3/ASR/conformer_ctc2/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
new file mode 120000
index 000000000..49b2ee483
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
@@ -0,0 +1 @@
+../transducer_stateless/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/attention.py b/egs/tedlium3/ASR/conformer_ctc2/attention.py
new file mode 100644
index 000000000..178cd7e62
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/attention.py
@@ -0,0 +1,201 @@
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+from scaling import ScaledLinear
+
+
+class MultiheadAttention(torch.nn.Module):
+    """Allows the model to jointly attend to information
+    from different representation subspaces. This is a modified
+    version of the original version of multihead attention
+    (see Attention Is All You Need )
+    with replacement of input / output projection layers
+    with newly introduced ScaledLinear layer
+    (see https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py).
+
+    Args:
+        embed_dim:
+          total dimension of the model.
+        num_heads:
+          number of parallel attention heads. Note that embed_dim will be split
+          across num_heads, i.e. each head will have dimension (embed_dim // num_heads).
+        dropout:
+          dropout probability on attn_output_weights. (default=0.0).
+        bias:
+          if specified, adds bias to input / output projection layers (default=True).
+        add_bias_kv:
+          if specified, adds bias to the key and value sequences at dim=0 (default=False).
+        add_zero_attn:
+          if specified, adds a new batch of zeros to the key and value sequences
+          at dim=1 (default=False).
+        batch_first:
+          if True, then the input and output tensors are provided as
+          (batch, seq, feature), otherwise (seq, batch, feature) (default=False).
+
+    Examples::
+        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        bias: bool = True,
+        add_bias_kv: bool = False,
+        add_zero_attn: bool = False,
+        batch_first: bool = False,
+        device: Union[torch.device, str, None] = None,
+        dtype: Union[torch.dtype, str, None] = None,
+    ) -> None:
+
+        super().__init__()
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.batch_first = batch_first
+
+        if embed_dim % num_heads != 0:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads. "
+                "Got embedding dim vs number 0f heads: "
+                f"{embed_dim} vs {num_heads}"
+            )
+
+        self.head_dim = embed_dim // num_heads
+
+        self.in_proj = ScaledLinear(
+            embed_dim,
+            3 * embed_dim,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+        self.out_proj = ScaledLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            initial_scale=0.25,
+            device=device,
+            dtype=dtype,
+        )
+
+        if add_bias_kv:
+            self.bias_k = torch.nn.Parameter(
+                torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
+            )
+            self.bias_v = torch.nn.Parameter(
+                torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
+            )
+        else:
+            self.register_parameter("bias_k", None)
+            self.register_parameter("bias_v", None)
+
+        self.add_zero_attn = add_zero_attn
+
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        if self.bias_k is not None:
+            torch.nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            torch.nn.init.xavier_normal_(self.bias_v)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = True,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+            query:
+              Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q)
+              when batch_first=True, where L is the target sequence length, N is the batch size,
+              and E_q is the query embedding dimension embed_dim. Queries are compared against
+              key-value pairs to produce the output. See "Attention Is All You Need" for more details.
+            key:
+              Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when
+              batch_first=True, where S is the source sequence length, N is the batch size, and
+              E_k is the key embedding dimension kdim. See "Attention Is All You Need" for more details.
+            value:
+              Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when
+              batch_first=True, where S is the source sequence length, N is the batch size, and
+              E_v is the value embedding dimension vdim. See "Attention Is All You Need" for more details.
+            key_padding_mask:
+              If specified, a mask of shape (N, S) indicating which elements within key
+              to ignore for the purpose of attention (i.e. treat as "padding").
+              Binary and byte masks are supported. For a binary mask, a True value indicates
+              that the corresponding key value will be ignored for the purpose of attention.
+              For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.
+            need_weights:
+              If specified, returns attn_output_weights in addition to attn_outputs (default=True).
+            attn_mask:
+              If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
+              (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length,
+              and S is the source sequence length. A 2D mask will be broadcasted across the batch while
+              a 3D mask allows for a different mask for each entry in the batch.
+              Binary, byte, and float masks are supported. For a binary mask, a True value indicates
+              that the corresponding position is not allowed to attend. For a byte mask, a non-zero
+              value indicates that the corresponding position is not allowed to attend. For a float mask,
+              the mask values will be added to the attention weight.
+
+        Returns:
+            attn_output:
+              Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True,
+              where L is the target sequence length, N is the batch size, and E is the embedding dimension
+              embed_dim.
+            attn_output_weights:
+              Attention output weights of shape (N, L, S), where N is the batch size, L is the target sequence
+              length, and S is the source sequence length. Only returned when need_weights=True.
+        """
+        if self.batch_first:
+            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
+        (
+            attn_output,
+            attn_output_weights,
+        ) = torch.nn.functional.multi_head_attention_forward(
+            query,
+            key,
+            value,
+            self.embed_dim,
+            self.num_heads,
+            in_proj_weight=self.in_proj.get_weight(),
+            in_proj_bias=self.in_proj.get_bias(),
+            bias_k=self.bias_k,
+            bias_v=self.bias_v,
+            add_zero_attn=self.add_zero_attn,
+            dropout_p=self.dropout,
+            out_proj_weight=self.out_proj.get_weight(),
+            out_proj_bias=self.out_proj.get_bias(),
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+        )
+
+        if self.batch_first:
+            return attn_output.transpose(1, 0), attn_output_weights
+        return attn_output, attn_output_weights
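+
+
+# A minimal, self-contained usage sketch of the class above (shape checks only).
+# The helper below is illustrative: it mirrors the Examples in the docstring and
+# follows the self-test convention used by combiner.py in this recipe; nothing
+# else depends on it.
+def _test_multihead_attention() -> None:
+    embed_dim, num_heads = 256, 4
+    tgt_len, src_len, batch_size = 10, 12, 2
+    m = MultiheadAttention(embed_dim, num_heads)
+    # batch_first=False (the default), so inputs are (seq_len, batch, embed_dim)
+    query = torch.randn(tgt_len, batch_size, embed_dim)
+    key = torch.randn(src_len, batch_size, embed_dim)
+    value = torch.randn(src_len, batch_size, embed_dim)
+    attn_output, attn_output_weights = m(query, key, value)
+    assert attn_output.shape == (tgt_len, batch_size, embed_dim)
+    assert attn_output_weights.shape == (batch_size, tgt_len, src_len)
+
+
+if __name__ == "__main__":
+    _test_multihead_attention()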
diff --git a/egs/tedlium3/ASR/conformer_ctc2/combiner.py b/egs/tedlium3/ASR/conformer_ctc2/combiner.py
new file mode 100644
index 000000000..ff526029d
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/combiner.py
@@ -0,0 +1,244 @@
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+
+class RandomCombine(torch.nn.Module):
+    """
+    This module combines a list of Tensors, all with the same shape, to
+    produce a single output of that same shape which, in training time,
+    is a random combination of all the inputs; but which in test time
+    will be just the last input.
+    The idea is that the list of Tensors will be a list of outputs of multiple
+    conformer layers.  This has a similar effect as iterated loss. (See:
+    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
+    NETWORKS).
+    """
+
+    def __init__(
+        self,
+        num_inputs: int,
+        final_weight: float = 0.5,
+        pure_prob: float = 0.5,
+        stddev: float = 2.0,
+    ) -> None:
+        """
+        Args:
+          num_inputs:
+            The number of tensor inputs, which equals the number of layers'
+            outputs that are fed into this module.  E.g. in an 18-layer neural
+            net if we output layers 16, 12, 18, num_inputs would be 3.
+          final_weight:
+            The amount of weight or probability we assign to the
+            final layer when randomly choosing layers or when choosing
+            continuous layer weights.
+          pure_prob:
+            The probability, on each frame, with which we choose
+            only a single layer to output (rather than an interpolation)
+          stddev:
+            A standard deviation that we add to log-probs for computing
+            randomized weights.
+        The method of choosing which layers, or combinations of layers, to use,
+        is conceptually as follows::
+            With probability `pure_prob`::
+               With probability `final_weight`: choose final layer,
+               Else: choose random non-final layer.
+            Else::
+               Choose initial log-weights that correspond to assigning
+               weight `final_weight` to the final layer and equal
+               weights to other layers; then add Gaussian noise
+               with variance `stddev` to these log-weights, and normalize
+               to weights (note: the average weight assigned to the
+               final layer here will not be `final_weight` if stddev>0).
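+        For example, with `final_weight=0.5` and `pure_prob=0.333` (the values
+        passed by ConformerEncoder in conformer.py of this recipe), roughly one
+        frame in three uses a single layer's output (the final layer for about
+        half of those frames), and every other frame uses a noisy
+        softmax-weighted mixture of all the collected layer outputs.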
+        """
+        super().__init__()
+        assert 0 <= pure_prob <= 1, pure_prob
+        assert 0 < final_weight < 1, final_weight
+        assert num_inputs >= 1, num_inputs
+
+        self.num_inputs = num_inputs
+        self.final_weight = final_weight
+        self.pure_prob = pure_prob
+        self.stddev = stddev
+
+        self.final_log_weight = (
+            torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1))
+            .log()
+            .item()
+        )
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function.
+        Args:
+          inputs:
+            A list of Tensor, e.g. from various layers of a transformer.
+            All must be the same shape, of (*, num_channels)
+        Returns:
+          A Tensor of shape (*, num_channels). In test mode
+          this is just the final input.
+        """
+        num_inputs = self.num_inputs
+        assert len(inputs) == num_inputs, f"{len(inputs)}, {num_inputs}"
+        if not self.training or torch.jit.is_scripting() or len(inputs) == 1:
+            return inputs[-1]
+
+        # Shape of weights: (*, num_inputs)
+        num_channels = inputs[0].shape[-1]
+        num_frames = inputs[0].numel() // num_channels
+
+        ndim = inputs[0].ndim
+        # stacked_inputs: (num_frames, num_channels, num_inputs)
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
+            (num_frames, num_channels, num_inputs)
+        )
+
+        # weights: (num_frames, num_inputs)
+        weights = self._get_random_weights(
+            inputs[0].dtype, inputs[0].device, num_frames
+        )
+
+        weights = weights.reshape(num_frames, num_inputs, 1)
+        # ans: (num_frames, num_channels, 1)
+        ans = torch.matmul(stacked_inputs, weights)
+        # ans: (*, num_channels)
+
+        ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
+
+        return ans
+
+    def _get_random_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), such that
+          `ans.sum(dim=1)` is all ones.
+        """
+        pure_prob = self.pure_prob
+        if pure_prob == 0.0:
+            return self._get_random_mixed_weights(dtype, device, num_frames)
+        elif pure_prob == 1.0:
+            return self._get_random_pure_weights(dtype, device, num_frames)
+        else:
+            p = self._get_random_pure_weights(dtype, device, num_frames)
+            m = self._get_random_mixed_weights(dtype, device, num_frames)
+            return torch.where(
+                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
+            )
+
+    def _get_random_pure_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random one-hot weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
+          exactly one weight equal to 1.0 on each frame.
+        """
+        final_prob = self.final_weight
+
+        # final contains self.num_inputs - 1 in all elements
+        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
+        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
+        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
+
+        indexes = torch.where(
+            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
+        )
+        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(
+            dtype=dtype
+        )
+        return ans
+
+    def _get_random_mixed_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random one-hot weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), whose elements
+          are in [0..1] and sum to one over the second axis, i.e.
+          `ans.sum(dim=1)` is all ones.
+        """
+        logprobs = (
+            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
+            * self.stddev
+        )
+        logprobs[:, -1] += self.final_log_weight
+        return logprobs.softmax(dim=1)
+
+
+def _test_random_combine(
+    final_weight: float,
+    pure_prob: float,
+    stddev: float,
+) -> None:
+    print(
+        f"_test_random_combine: final_weight={final_weight}, "
+        f"pure_prob={pure_prob}, stddev={stddev}"
+    )
+    num_inputs = 3
+    num_channels = 50
+    m = RandomCombine(
+        num_inputs=num_inputs,
+        final_weight=final_weight,
+        pure_prob=pure_prob,
+        stddev=stddev,
+    )
+
+    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
+
+    y = m(x)
+    assert y.shape == x[0].shape
+    assert torch.allclose(y, x[0])  # .. since actually all ones.
+
+
+def _test_random_combine_main() -> None:
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.0)
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.3)
+    _test_random_combine(0.5, 1, 0.3)
+    _test_random_combine(0.5, 0.5, 0.3)
+
+
+if __name__ == "__main__":
+    _test_random_combine_main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/conformer.py b/egs/tedlium3/ASR/conformer_ctc2/conformer.py
new file mode 100644
index 000000000..fad2f371f
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/conformer.py
@@ -0,0 +1,1033 @@
+#!/usr/bin/env python3
+# Copyright (c)  2021  University of Chinese Academy of Sciences (author: Han Zhu)
+#                2022  Xiaomi Corp.                              (author: Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+import warnings
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from combiner import RandomCombine
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledLinear,
+)
+from subsampling import Conv2dSubsampling
+from transformer import Supervisions, Transformer, encoder_padding_mask
+
+
+class Conformer(Transformer):
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+        aux_layer_period: int = 3,
+    ) -> None:
+        """
+        Args:
+          num_features (int):
+            number of input features.
+          num_classes (int):
+            number of output classes.
+          subsampling_factor (int):
+            subsampling factor of encoder;
+            currently, subsampling_factor MUST be 4.
+          d_model (int):
+            attention dimension, also the output dimension.
+          nhead (int):
+            number of heads in multi-head attention;
+            must satisfy d_model % nhead == 0.
+          dim_feedforward (int):
+            feedforward dimension.
+          num_encoder_layers (int):
+            number of encoder layers.
+          num_decoder_layers (int):
+            number of decoder layers.
+          dropout (float):
+            dropout rate.
+          layer_dropout (float):
+            layer-dropout rate.
+          cnn_module_kernel (int):
+            kernel size of convolution module.
+          aux_layer_period (int):
+            determines the auxiliary encoder layers.
+        """
+
+        super().__init__(
+            num_features=num_features,
+            num_classes=num_classes,
+            subsampling_factor=subsampling_factor,
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            num_encoder_layers=num_encoder_layers,
+            num_decoder_layers=num_decoder_layers,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+        )
+
+        self.num_features = num_features
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+        encoder_layer = ConformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+            cnn_module_kernel=cnn_module_kernel,
+        )
+
+        # aux_layers from 1/3
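+        # (e.g. with the defaults num_encoder_layers=12 and aux_layer_period=3
+        # this yields aux_layers=[4, 7, 10]; ConformerEncoder appends the final
+        # layer itself, so RandomCombine mixes 4 outputs)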
+        self.encoder = ConformerEncoder(
+            encoder_layer=encoder_layer,
+            num_layers=num_encoder_layers,
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            ),
+        )
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        supervisions: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          x:
+            the input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling
+            It is read directly from the batch, without any sorting. It is used
+            to compute encoder padding mask, which is used as memory key padding
+            mask for the decoder.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up".  It is used
+            to turn modules on sequentially.
+
+        Returns:
+          torch.Tensor: Predictor tensor of dimension (S, N, C).
+          torch.Tensor: Mask tensor of dimension (N, S)
+        """
+        x = self.encoder_embed(x)
+        x, pos_emb = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, S, C) -> (S, N, C)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        mask = mask.to(x.device) if mask is not None else None
+
+        x = self.encoder(
+            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
+        )  # (S, N, C)
+
+        return x, mask
+
+
+class ConformerEncoderLayer(nn.Module):
+    """
+    ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
+    See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
+
+    Examples:
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = encoder_layer(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+        cnn_module_kernel: int = 31,
+    ) -> None:
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+          cnn_module_kernel (int):
+            kernel size of convolution module (default=31).
+        """
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.feed_forward_macaron = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        pos_emb: torch.Tensor,
+        src_mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+          src:
+            the sequence to the encoder layer of shape (S, N, C) (required).
+          pos_emb:
+            positional embedding tensor of shape (N, 2*S-1, C) (required).
+          src_mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layers; if < 1.0, we will
+            bypass layers more frequently.
+
+        Returns:
+            Output tensor of the shape (S, N, C), where
+            S is the source sequence length,
+            N is the batch size,
+            C is the feature number
+        """
+        src_orig = src
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
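+        # E.g. with the defaults bypass_scale=0.1 and layer_dropout=0.075, and
+        # warmup=0.5 during early training, alpha = min(0.1 + 0.5, 1.0) = 0.6
+        # with probability 0.925 and alpha = bypass_scale = 0.1 otherwise;
+        # at inference alpha is always 1.0, i.e. the layer is never bypassed.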
+
+        # macaron style feed forward module
+        src = src + self.dropout(self.feed_forward_macaron(src))
+
+        # multi-headed self-attention module
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+
+        src = src + self.dropout(src_att)
+
+        # convolution module
+        src = src + self.dropout(self.conv_module(src))
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1 - alpha) * src_orig
+
+        return src
+
+
+class ConformerEncoder(nn.Module):
+    """
+    ConformerEncoder is a stack of N encoder layers
+
+    Examples:
+        >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+        >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = conformer_encoder(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+
+        """
+        Args:
+          encoder_layer:
+            an instance of the ConformerEncoderLayer() class (required).
+          num_layers:
+            the number of sub-encoder-layers in the encoder (required).
+          aux_layers:
+            list of indexes of sub-encoder-layers outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        pos_emb: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layers in turn.
+
+        Args:
+          src:
+            the sequence to the encoder of shape (S, N, C) (required).
+          pos_emb:
+            positional embedding tensor of shape (N, 2*S-1, C) (required).
+          mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layer; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = src
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                pos_emb,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """
+    Relative positional encoding module.
+
+    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+        """
+        Construct an PositionalEncoding object.
+
+        Args:
+          d_model: Embedding dimension.
+          dropout_rate: Dropout rate.
+          max_len: Maximum input length.
+
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """
+        Reset the positional encodings.
+
+        Args:
+          x:
+            input tensor (N, T, C), where
+            T is the source sequence length,
+            N is the batch size.
+            C is the feature number.
+
+        """
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means to the position of query vector and `j` means the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Add positional encoding.
+
+        Args:
+          x:
+            input tensor (N, T, C).
+
+        Returns:
+          torch.Tensor: Encoded tensor (N, T, C).
+          torch.Tensor: Encoded tensor (N, 2*T-1, C), where
+          T is the source sequence length,
+          N is the batch size.
+          C is the feature number.
+
+        """
+        self.extend_pe(x)
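+        # self.pe stores encodings for relative positions ordered from the most
+        # positive down to the most negative, with relative position 0 at index
+        # self.pe.size(1) // 2.  The slice below keeps the 2*T-1 central
+        # entries, i.e. relative positions T-1 .. -(T-1) for an input of
+        # length T = x.size(1).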
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    """
+    Multi-Head Attention layer with relative position encoding.
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
+
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ) -> None:
+        """
+        Args:
+          embed_dim:
+            total dimension of the model.
+          num_heads:
+            parallel attention heads.
+          dropout:
+            a Dropout layer on attn_output_weights. Default: 0.0.
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
+
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        # these two learnable biases are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+        self._reset_parameters()
+
+    def _pos_bias_u(self):
+        return self.pos_bias_u * self.pos_bias_u_scale.exp()
+
+    def _pos_bias_v(self):
+        return self.pos_bias_v * self.pos_bias_v_scale.exp()
+
+    def _reset_parameters(self) -> None:
+        nn.init.normal_(self.pos_bias_u, std=0.01)
+        nn.init.normal_(self.pos_bias_v, std=0.01)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_emb: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          query, key, value: map a query and a set of key-value pairs to an output.
+          pos_emb: Positional embedding tensor
+          key_padding_mask: if provided, specified padding elements in the key will
+                            be ignored by the attention. When given a binary mask
+                            and a value is True, the corresponding value on the attention
+                            layer will be ignored. When given a byte mask and a value is
+                            non-zero, the corresponding value on the attention layer will be ignored.
+          need_weights: output attn_output_weights.
+          attn_mask: 2D or 3D mask that prevents attention to certain positions.
+                     A 2D mask will be broadcasted for all the batches while a 3D
+                     mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+          - Inputs:
+          - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero
+            positions will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
+          - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+          - Outputs:
+          - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+          - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+        return self.multi_head_attention_forward(
+            query,
+            key,
+            value,
+            pos_emb,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj.get_weight(),
+            self.in_proj.get_bias(),
+            self.dropout,
+            self.out_proj.get_weight(),
+            self.out_proj.get_bias(),
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+        )
+
+    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Compute relative positional encoding.
+
+        Args:
+          x:
+            input tensor (batch, head, time1, 2*time1-1).
+            time1 means the length of the query vector.
+
+        Returns:
+          torch.Tensor: tensor of shape (batch, head, time1, time2)
+          (note: time2 has the same value as time1, but it is for
+          the key, while time1 is for the query).
+        """
+        (batch_size, num_heads, time1, n) = x.shape
+        assert n == 2 * time1 - 1
+        # Note: TorchScript requires explicit arg for stride()
+        batch_stride = x.stride(0)
+        head_stride = x.stride(1)
+        time1_stride = x.stride(2)
+        n_stride = x.stride(3)
+        return x.as_strided(
+            (batch_size, num_heads, time1, time1),
+            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+            storage_offset=n_stride * (time1 - 1),
+        )
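+    # Illustrative sketch of rel_shift (example values only, assuming a
+    # contiguous input):
+    #   x = torch.arange(15.0).reshape(1, 1, 3, 5)   # time1 = 3
+    #   rel_shift(x)[0, 0] == [[ 2.,  3.,  4.],
+    #                          [ 6.,  7.,  8.],
+    #                          [10., 11., 12.]]
+    # i.e. row i keeps entries (time1 - 1 - i) .. (2 * time1 - 2 - i) of the
+    # relative-position axis, aligning relative offsets with key positions.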
+
+    def multi_head_attention_forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_emb: torch.Tensor,
+        embed_dim_to_check: int,
+        num_heads: int,
+        in_proj_weight: torch.Tensor,
+        in_proj_bias: torch.Tensor,
+        dropout_p: float,
+        out_proj_weight: torch.Tensor,
+        out_proj_bias: torch.Tensor,
+        training: bool = True,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          query, key, value: map a query and a set of key-value pairs to an output.
+          pos_emb: Positional embedding tensor
+          embed_dim_to_check: total dimension of the model.
+          num_heads: parallel attention heads.
+          in_proj_weight, in_proj_bias: input projection weight and bias.
+          dropout_p: probability of an element to be zeroed.
+          out_proj_weight, out_proj_bias: the output projection weight and bias.
+          training: apply dropout if is ``True``.
+          key_padding_mask: if provided, specified padding elements in the key will
+                            be ignored by the attention. This is a binary mask.
+                            When the value is True, the corresponding value on the
+                            attention layer will be filled with -inf.
+          need_weights: output attn_output_weights.
+          attn_mask: 2D or 3D mask that prevents attention to certain positions.
+                     A 2D mask will be broadcasted for all the batches while a 3D
+                     mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+          Inputs:
+          - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+          - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+          - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+            length, N is the batch size, E is the embedding dimension.
+          - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+            will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
+          - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+          Outputs:
+          - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+          - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+
+        tgt_len, bsz, embed_dim = query.size()
+        assert embed_dim == embed_dim_to_check
+        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+        head_dim = embed_dim // num_heads
+        assert (
+            head_dim * num_heads == embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        scaling = float(head_dim) ** -0.5
+
+        if torch.equal(query, key) and torch.equal(key, value):
+            # self-attention
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
+
+        elif torch.equal(key, value):
+            # encoder-decoder attention
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+
+        else:
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = embed_dim * 2
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            k = nn.functional.linear(key, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim * 2
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            v = nn.functional.linear(value, _w, _b)
+
+        if attn_mask is not None:
+            assert (
+                attn_mask.dtype == torch.float32
+                or attn_mask.dtype == torch.float64
+                or attn_mask.dtype == torch.float16
+                or attn_mask.dtype == torch.uint8
+                or attn_mask.dtype == torch.bool
+            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
+                attn_mask.dtype
+            )
+            if attn_mask.dtype == torch.uint8:
+                warnings.warn(
+                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
+                )
+                attn_mask = attn_mask.to(torch.bool)
+
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+            elif attn_mask.dim() == 3:
+                if list(attn_mask.size()) != [
+                    bsz * num_heads,
+                    query.size(0),
+                    key.size(0),
+                ]:
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+            else:
+                raise RuntimeError(
+                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
+                )
+            # attn_mask's dim is 3 now.
+
+        # convert ByteTensor key_padding_mask to bool
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+            warnings.warn(
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+            )
+            key_padding_mask = key_padding_mask.to(torch.bool)
+
+        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+        src_len = k.size(0)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+                key_padding_mask.size(0), bsz
+            )
+            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
+                key_padding_mask.size(1), src_len
+            )
+
+        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+
+        pos_emb_bsz = pos_emb.size(0)
+        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+
+        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+
+        # compute matrix b and matrix d
+        matrix_bd = torch.matmul(
+            q_with_bias_v, p.transpose(-2, -1)
+        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd)
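+        # matrix_ac covers terms (a) + (c) and matrix_bd terms (b) + (d) of the
+        # decomposition in Section 3.3 of the Transformer-XL paper; rel_shift
+        # re-indexes the relative-position axis of matrix_bd (2*time1-1 entries)
+        # so that it lines up with the absolute key positions of matrix_ac.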
+
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+
+        assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+            else:
+                attn_output_weights += attn_mask
+
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
+        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+        attn_output_weights = nn.functional.dropout(
+            attn_output_weights, p=dropout_p, training=training
+        )
+
+        attn_output = torch.bmm(attn_output_weights, v)
+        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+        attn_output = (
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
+
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+        else:
+            return attn_output, None
+
+
+class ConvolutionModule(nn.Module):
+    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
+        """
+        ConvolutionModule in Conformer model.
+        Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+        Construct a ConvolutionModule object.
+
+        Args:
+          channels (int):
+            the number of channels of conv layers.
+          kernel_size (int):
+            kernel size of conv layers.
+          bias (bool):
+            whether to use bias in conv layers (default=True).
+        """
+        super().__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = ScaledConv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        # After pointwise_conv1 we put x through a gated linear unit
+        # (nn.functional.glu).  For most layers the normal rms value of the
+        # channels of x seems to be in the range 1 to 4, but sometimes, for
+        # some reason, for layer 0 the rms ends up being very large, between
+        # 50 and 100 for different channels.  This causes very peaky and
+        # sparse derivatives for the sigmoid gating function, which tends to
+        # keep the loss function from learning effectively.  (For most layers
+        # the average absolute values are in the range 0.5..9.0, and the
+        # average p(x>0), i.e. the positive proportion, at the output of
+        # pointwise_conv1 is around 0.35 to 0.45 for different layers, which
+        # likely breaks down as 0.5 for the "linear" half and 0.2 to 0.3 for
+        # the part that goes into the sigmoid.)  The idea is that if we
+        # constrain the rms values to a reasonable range via a constraint of
+        # max_abs=10.0, the layer will be in a better position to start
+        # learning something, i.e. to latch onto the correct range.
+        self.deriv_balancer1 = ActivationBalancer(
+            channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
+        )
+
+        self.depthwise_conv = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.deriv_balancer2 = ActivationBalancer(
+            channel_dim=1, min_positive=0.05, max_positive=1.0
+        )
+
+        self.activation = DoubleSwish()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            initial_scale=0.25,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute convolution module.
+
+        Args:
+          x:
+            input tensor of shape (T, N, C).
+
+        Returns:
+          torch.Tensor: Output tensor (T, N, C), where
+          T is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        return x.permute(2, 0, 1)
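+
+
+# A minimal usage sketch of ConvolutionModule (illustrative only; the channel
+# count and kernel size below are example values, not recipe defaults):
+#
+#   conv = ConvolutionModule(channels=256, kernel_size=31)
+#   x = torch.randn(50, 8, 256)   # (T, N, C)
+#   y = conv(x)                   # -> (50, 8, 256)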
diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py
new file mode 100755
index 000000000..28d39de70
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py
@@ -0,0 +1,896 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
+#                                            Fangjun Kuang,
+#                                            Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import TedLiumAsrDataModule
+from conformer import Conformer
+from train import add_model_arguments
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    get_texts,
+    load_averaged_model,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="attention-decoder",
+        help="""Decoding method.
+        Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
+            - (1) ctc-greedy-search. It only uses the CTC output and a sentence piece
+              model for decoding. It produces the same results as ctc-decoding.
+            - (2) 1best. Extract the best path from the decoding lattice as the
+              decoding result.
+            - (3) nbest. Extract n paths from the decoding lattice; the path
+              with the highest score is the decoding result.
+            - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+              the highest score is the decoding result.
+            - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+              is the decoding result.
+            - (6) attention-decoder. Extract n paths from the LM rescored
+              lattice, the path with the highest score is the decoding result.
+            - (7) nbest-oracle. Its WER is the lower bound that any n-best
+              rescoring method can achieve. Useful for debugging n-best
+              rescoring methods.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kinds of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--lm-path",
+        type=str,
+        default="data/lm/G_4_gram.pt",
+        help="""The n-gram LM dir for rescoring.
+        It should contain either lm_fname.pt or lm_fname.fst.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--result-dir",
+        type=str,
+        default="conformer_ctc2/exp/results",
+        help="Directory to store results.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+    """
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "feature_dim": 80,
+            # parameters for decoding
+            "search_beam": 15,
+            "output_beam": 8,
+            "min_active_states": 10,
+            "max_active_states": 7000,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def ctc_greedy_search(
+    ctc_probs: torch.Tensor,
+    mask: torch.Tensor,
+) -> List[List[int]]:
+    """Apply CTC greedy search
+    Args:
+      ctc_probs (torch.Tensor): (batch, max_len, num_bpe)
+      mask (torch.Tensor): (batch, max_len)
+    Returns:
+      best path result
+    """
+
+    _, max_index = ctc_probs.max(2)  # (B, maxlen)
+    max_index = max_index.masked_fill_(mask, 0)  # (B, maxlen)
+
+    ret_hyps = []
+    for hyp in max_index:
+        hyp = torch.unique_consecutive(hyp)
+        hyp = hyp[hyp > 0].tolist()
+        ret_hyps.append(hyp)
+    return ret_hyps
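+# Illustrative sketch of ctc_greedy_search (example values only): with 0 as
+# the blank ID and
+#   max_index = [[0, 3, 3, 0, 5, 5, 0]]
+# torch.unique_consecutive collapses repeats to [0, 3, 0, 5, 0]; dropping the
+# zeros yields the hypothesis [3, 5].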
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if no rescoring is used, the key is the string `no_rescore`.
+               If LM rescoring is used, the key is the string `lm_scale_xxx`,
+               where `xxx` is the value of `lm_scale`. An example key is
+               `lm_scale_0.7`
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - params.method is "1best", it uses 1best decoding without LM rescoring.
+        - params.method is "nbest", it uses nbest decoding without LM rescoring.
+        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
+        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+    # nnet_output is (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            torch.div(
+                supervisions["start_frame"],
+                params.subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                params.subsampling_factor,
+                rounding_mode="floor",
+            ),
+        ),
+        1,
+    ).to(torch.int32)
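+    # For example, with subsampling_factor = 4, a supervision that starts at
+    # frame 100 and spans 400 frames for the utterance at batch index 1 becomes
+    # the row [1, 25, 100].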
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-lists of token IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        unk = bpe_model.decode(bpe_model.unk_id()).strip()
+        hyps = [[w for w in s.split() if w != unk] for s in hyps]
+        key = "ctc-decoding"
+
+        return {key: hyps}
+
+    if params.method == "ctc-greedy-search":
+        hyps = ctc_greedy_search(nnet_output, memory_key_padding_mask)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(hyps)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        unk = bpe_model.decode(bpe_model.unk_id()).strip()
+        hyps = [[w for w in s.split() if w != unk] for s in hyps]
+        key = "ctc-greedy-search"
+
+        return {key: hyps}
+
+    if params.method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="",
+        )
+        hyps = get_texts(best_path)
+        hyps = [
+            [word_table[i] for i in ids if word_table[i] != "<unk>"] for ids in hyps
+        ]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.method == "nbest":
+        best_path = nbest_decoding(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            use_double_scores=params.use_double_scores,
+            nbest_scale=params.nbest_scale,
+        )
+        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [
+            [word_table[i] for i in ids if word_table[i] != "<unk>"] for ids in hyps
+        ]
+        return {key: hyps}
+
+    assert params.method in [
+        "1best",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.method == "1best":
+        best_path_dict = one_best_decoding(
+            lattice=lattice,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    elif params.method == "attention-decoder":
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            nbest_scale=params.nbest_scale,
+        )
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [
+                [word_table[i] for i in ids if word_table[i] != "<unk>"] for ids in hyps
+            ]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
+      G:
+        An LM. It is not None when params.method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        if hyps_dict is not None:
+            for lm_scale, hyps in hyps_dict.items():
+                this_batch = []
+                assert len(hyps) == len(texts)
+                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                    ref_words = ref_text.split()
+                    this_batch.append((cut_id, ref_words, hyp_words))
+
+                results[lm_scale].extend(this_batch)
+        else:
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for cut_id, ref_text in zip(cut_ids, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            for lm_scale in results.keys():
+                results[lm_scale].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+) -> None:
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.result_dir / f"recogs-{test_set_name}-{key}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.result_dir / f"errs-{test_set_name}-{key}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
+            test_set_wers[key] = wer
+
+        if enable_log:
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.result_dir / f"wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main() -> None:
+    parser = get_parser()
+    TedLiumAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_path = Path(args.lm_path)
+    args.result_dir = Path(args.result_dir)
+
+    args.result_dir.mkdir(exist_ok=True)
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    if params.method in ("ctc-decoding", "ctc-greedy-search"):
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
+
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
+
+    if params.method in ("nbest-rescoring", "whole-lattice-rescoring"):
+        assert params.lm_path.suffix in (".pt", ".txt")
+
+        if params.lm_path.is_file() and params.lm_path.suffix == ".pt":
+            logging.info(f"Loading pre-compiled {params.lm_path.name}")
+            d = torch.load(params.lm_path, map_location=device)
+            G = k2.Fsa.from_dict(d)
+        elif not params.lm_path.is_file() and params.lm_path.suffix == ".txt":
+            raise FileNotFoundError(f"No such language model file: '{params.lm_path}'")
+        else:
+            # We reach here only if the LM filename ends with '.pt' but the
+            # file does not exist, or if it ends with '.txt' and the file exists.
+            if (
+                not params.lm_path.is_file()
+                and params.lm_path.suffix == ".pt"
+                and not (
+                    params.lm_path.parent / f"{params.lm_path.stem}.fst.txt"
+                ).is_file()
+            ):
+                raise FileNotFoundError(
+                    f"No such language model file: '{params.lm_path}'\n"
+                    "'.fst.txt' representation of the language model was "
+                    "not found either."
+                )
+            else:
+                # Whether params.lm_path.name is lm_name.pt or lm_name.fst.txt,
+                # we are going to load lm_name.fst.txt here.
+                params.lm_path = params.lm_path.parent / params.lm_path.name.replace(
+                    ".pt", ".fst.txt"
+                )
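+                # e.g. data/lm/G_4_gram.pt becomes data/lm/G_4_gram.fst.txt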
+                logging.info(f"Loading {params.lm_path.name}")
+                logging.warning("It may take 8 minutes.")
+                with open(params.lm_path) as f:
+                    first_word_disambig_id = lexicon.word_table["#0"]
+
+                    G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                    # G.aux_labels is not needed in later computations, so
+                    # remove it here.
+                    del G.aux_labels
+                    # CAUTION: The following line is crucial.
+                    # Arcs entering the back-off state have label equal to #0.
+                    # We have to change it to 0 here.
+                    G.labels[G.labels >= first_word_disambig_id] = 0
+                    # See https://github.com/k2-fsa/k2/issues/874
+                    # for why we need to set G.properties to None
+                    G.__dict__["_properties"] = None
+                    G = k2.Fsa.from_fsas([G]).to(device)
+                    G = k2.arc_sort(G)
+                    # Save a dummy value so that it can be loaded in C++.
+                    # See https://github.com/pytorch/pytorch/issues/67902
+                    # for why we need to do this.
+                    G.dummy = 1
+
+                    torch.save(
+                        G.as_dict(),
+                        params.lm_path.parent
+                        / params.lm_path.name.replace(".fst.txt", ".pt"),
+                    )
+
+        if params.method == "whole-lattice-rescoring":
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+
+        # G.lm_scores is used to replace HLG.lm_scores during
+        # LM rescoring.
+        G.lm_scores = G.scores.clone()
+    else:
+        G = None
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    tedlium = TedLiumAsrDataModule(args)
+
+    valid_cuts = tedlium.dev_cuts()
+    test_cuts = tedlium.test_cuts()
+
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
+    test_dl = tedlium.test_dataloaders(test_cuts)
+
+    test_sets = ["dev", "test"]
+    test_dls = [valid_dl, test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            word_table=lexicon.word_table,
+            G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
+        )
+
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+# When we import add_model_arguments from train.py,
+# torch.set_num_interop_threads(1) is called there at import time,
+# so we would end up setting the number of interop threads twice
+# (once in train.py and once in decode.py), which raises an error.
+# That is why the additional if statement below is needed.
+if torch.get_num_interop_threads() != 1:
+    torch.set_num_interop_threads(1)
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py
new file mode 100755
index 000000000..009bea230
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/export.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+#
+# Copyright 2022 Behavox LLC (Author: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./conformer_ctc2/export.py \
+  --exp-dir ./conformer_ctc2/exp \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `conformer_ctc2/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/tedlium3/ASR
+    ./conformer_ctc2/decode.py \
+        --exp-dir ./conformer_ctc2/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 100
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from conformer import Conformer
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, str2bool
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help=(
+            "Number of checkpoints to average. Automatically select "
+            "consecutive checkpoints before the checkpoint specified by "
+            "'--epoch' and '--iter'"
+        ),
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help=(
+            "Whether to load averaged model. Currently it only supports "
+            "using --epoch. If True, it would decode with the averaged model "
+            "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+            "Actually only the models with epoch number of `epoch-avg` and "
+            "`epoch` are loaded for averaging. "
+        ),
+    )
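+    # For example, `--epoch 20 --avg 10 --use-averaged-model True` exports the
+    # model averaged over the epoch range (10, 20], computed from just
+    # epoch-10.pt and epoch-20.pt (see main() below).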
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=True,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+    """
+    # parameters for conformer
+    params = AttributeDict({"subsampling_factor": 4, "feature_dim": 80})
+    return params
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info(params)
+
+    logging.info("About to create model")
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                "Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/lstmp.py b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/optim.py b/egs/tedlium3/ASR/conformer_ctc2/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling.py b/egs/tedlium3/ASR/conformer_ctc2/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/subsampling.py b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py
new file mode 120000
index 000000000..8c91f2336
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc2/subsampling.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py
new file mode 100755
index 000000000..42e4c010a
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/train.py
@@ -0,0 +1,1061 @@
+#!/usr/bin/env python3
+# Copyright    2022  Behavox LLC.        (authors: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir conformer_ctc2/exp \
+  --max-duration 300
+
+# For mixed precision training:
+
+./conformer_ctc2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir conformer_ctc2/exp \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+from asr_datamodule import TedLiumAsrDataModule
+from conformer import Conformer
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser) -> None:
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=24,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layer of transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model)
+        """,
+    )
+
+    parser.add_argument(
+        "--att-rate",
+        type=float,
+        default=0.8,
+        help="""The attention rate.
+        The total loss is (1 -  att_rate) * ctc_loss + att_rate * att_loss
+        """,
+    )
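+    # For example, with the default att_rate=0.8 the total loss computed in
+    # compute_loss() below is 0.2 * ctc_loss + 0.8 * att_loss.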
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=1536,
+        help="Feedforward module dimension of the conformer model.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer multiheadattention modules.",
+    )
+
+    parser.add_argument(
+        "--dim-model",
+        type=int,
+        default=384,
+        help="Attention dimension in the conformer model.",
+    )
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt" and "bpe.model"
+        """,
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not changing this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="Number of epochs that affects how rapidly the learning rate decreases.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=4000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
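+    # For example, with average_period=100 and batch_idx_train=300 the update
+    # above becomes: model_avg = model * (100/300) + model_avg * (200/300).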
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - model_warm_step: Number of batches used to warm up the model
+                           (it is passed to the model, not to the LR schedule).
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 10,
+            "reset_interval": 200,
+            "valid_interval": 1000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for ctc loss
+            "beam_size": 10,
+            "reduction": "none",
+            "use_double_scores": True,
+            # parameters for model warmup
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: torch.nn.Module,
+    model_avg: torch.nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that is used for training.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
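+    # For example, `--start-epoch 11` resumes from exp-dir/epoch-10.pt, while
+    # `--start-batch 40000` resumes from exp-dir/checkpoint-40000.pt.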
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    model_avg: Optional[torch.nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used for training.
+      scheduler:
+        The learning rate scheduler used for training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(
+            feature, supervisions, warmup=warmup
+        )
+
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+
+        token_ids = convert_texts_into_ids(texts, graph_compiler.sp)
+        decoding_graph = graph_compiler.compile(token_ids)
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
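+        # Since get_params() sets reduction='none', ctc_loss contains one value
+        # per utterance, which is what makes the non-finite filtering below
+        # possible.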
+
+        if params.att_rate > 0.0:
+            with torch.set_grad_enabled(is_training):
+                mmodel = model.module if hasattr(model, "module") else model
+                # Note: We need to generate an unsorted version of token_ids
+                # `encode_supervisions()` called above sorts text, but
+                # encoder_memory and memory_mask are not sorted, so we
+                # use an unsorted version `supervisions["text"]` to regenerate
+                # the token_ids
+                #
+                # See https://github.com/k2-fsa/icefall/issues/97
+                # for more details
+                unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
+                att_loss = mmodel.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=unsorted_token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                    warmup=warmup,
+                )
+        else:
+            att_loss = torch.tensor([0])
+
+        ctc_loss_is_finite = torch.isfinite(ctc_loss)
+        att_loss_is_finite = torch.isfinite(att_loss)
+        if torch.any(~ctc_loss_is_finite) or torch.any(~att_loss_is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"ctc_loss: {ctc_loss}\n"
+                f"att_loss: {att_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            ctc_loss = ctc_loss[ctc_loss_is_finite]
+            att_loss = att_loss[att_loss_is_finite]
+
+            # If all ctc_loss values or all att_loss values are inf or nan,
+            # we stop the training process by raising an exception.
+            if torch.all(~ctc_loss_is_finite) or torch.all(~att_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+
+        ctc_loss = ctc_loss.sum()
+        att_loss = att_loss.sum()
+
+        if params.att_rate > 0.0:
+            loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+        else:
+            loss = ctc_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    # info["frames"] is an approximate number for two reasons:
+    # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+    # (2) If some utterances in the batch lead to inf/nan loss, they
+    #     are filtered out.
+    info["frames"] = (
+        torch.div(feature_lens, params.subsampling_factor, rounding_mode="floor")
+        .sum()
+        .item()
+    )
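+    # For example, an utterance with feature_lens == 100 contributes
+    # 100 // 4 == 25 frames here, while the exact number after subsampling is
+    # ((100 - 1) // 2 - 1) // 2 == 24; hence "approximate" above.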
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate > 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch in valid_dl:
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[torch.nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[torch.nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler; step_batch() is called after every batch
+        and step_epoch() at the start of every epoch (see run()).
+      graph_compiler:
+        It is used to convert transcripts to FSAs.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of processes (GPUs) in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of this process in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
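+            # This keeps a running sum of the statistics with decay factor
+            # (1 - 1/reset_interval), i.e. 0.995 for the default
+            # reset_interval=200, so old batches are gradually forgotten.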
+
+            # NOTE: We use reduction='sum', so the loss is summed over the
+            # utterances in the batch without any normalization so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    if "lang_bpe" not in str(params.lang_dir):
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' in its name): {params.lang_dir}"
+        )
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="",
+        eos_token="",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[torch.nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = optim.Eve(model.parameters(), lr=params.initial_lr)
+    scheduler = optim.Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and checkpoints.get("optimizer") is not None:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if checkpoints and checkpoints.get("scheduler") is not None:
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    tedlium = TedLiumAsrDataModule(args)
+
+    train_cuts = tedlium.train_cuts()
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = tedlium.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = tedlium.dev_cuts()
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
+
+    if (
+        params.start_epoch <= 1
+        and params.start_batch <= 0
+        and not params.print_diagnostics
+    ):
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+        train_dl.dataset.epoch = epoch - 1
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[torch.nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: BpeCtcTrainingGraphCompiler,
+    params: AttributeDict,
+    warmup: float,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=warmup,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=graph_compiler.sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    TedLiumAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/tedlium3/ASR/conformer_ctc2/transformer.py b/egs/tedlium3/ASR/conformer_ctc2/transformer.py
new file mode 100644
index 000000000..9dbf32e48
--- /dev/null
+++ b/egs/tedlium3/ASR/conformer_ctc2/transformer.py
@@ -0,0 +1,1093 @@
+# Copyright    2021  University of Chinese Academy of Sciences (author: Han Zhu)
+# Copyright    2022  Xiaomi Corp.                              (author: Quandong Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from attention import MultiheadAttention
+from combiner import RandomCombine
+from label_smoothing import LabelSmoothingLoss
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledEmbedding,
+    ScaledLinear,
+)
+from subsampling import Conv2dSubsampling
+from torch.nn.utils.rnn import pad_sequence
+
+# Note: TorchScript requires Dict/List/etc. to be fully typed.
+Supervisions = Dict[str, torch.Tensor]
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        aux_layer_period: int = 3,
+    ) -> None:
+        """
+        Args:
+          num_features:
+            the input dimension of the model.
+          num_classes:
+            the output dimension of the model.
+          subsampling_factor:
+            number of output frames is num_in_frames // subsampling_factor;
+            currently, subsampling_factor MUST be 4.
+          d_model:
+            attention dimension.
+          nhead:
+            number of heads in multi-head attention;
+            must satisfy d_model % nhead == 0.
+          dim_feedforward:
+            the output dimension of the feedforward layers in encoder/decoder.
+          num_encoder_layers:
+            number of encoder layers.
+          num_decoder_layers:
+            number of decoder layers.
+          dropout:
+            dropout in encoder/decoder.
+          layer_dropout:
+            layer-dropout rate.
+          aux_layer_period:
+            period used to select the auxiliary encoder layers (every
+            `aux_layer_period`-th layer, starting one third of the way
+            into the encoder).
+        """
+        super().__init__()
+
+        self.num_features = num_features
+        self.num_classes = num_classes
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_classes -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
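+        # For example, with num_features=80 and d_model=384 (the values used by
+        # train.py in this recipe), an input of shape (N, T, 80) becomes
+        # (N, T//4, 384).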
+
+        self.encoder_pos = PositionalEncoding(d_model, dropout)
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
+        )
+        # aux_layers from 1/3
+        self.encoder = TransformerEncoder(
+            encoder_layer=encoder_layer,
+            num_layers=num_encoder_layers,
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            ),
+        )
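+        # With the constructor defaults num_encoder_layers=12 and
+        # aux_layer_period=3, aux_layers evaluates to [4, 7, 10]; these are the
+        # layers whose outputs the RandomCombine module imported above is
+        # presumably applied to.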
+
+        # TODO(fangjun): remove dropout
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), ScaledLinear(d_model, num_classes, bias=True)
+        )
+
+        if num_decoder_layers > 0:
+            self.decoder_num_class = (
+                self.num_classes
+            )  # bpe model already has sos/eos symbol
+
+            self.decoder_embed = ScaledEmbedding(
+                num_embeddings=self.decoder_num_class, embedding_dim=d_model
+            )
+            self.decoder_pos = PositionalEncoding(d_model, dropout)
+
+            decoder_layer = TransformerDecoderLayer(
+                d_model=d_model,
+                nhead=nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+            )
+
+            self.decoder = TransformerDecoder(
+                decoder_layer=decoder_layer,
+                num_layers=num_decoder_layers,
+                aux_layers=[],
+            )
+
+            self.decoder_output_layer = ScaledLinear(
+                d_model, self.decoder_num_class, bias=True
+            )
+
+            self.decoder_criterion = LabelSmoothingLoss(reduction="none")
+        else:
+            self.decoder_criterion = None
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        supervision: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (N, S, C).
+          supervision:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            (CAUTION: It contains length information, i.e., start and number of
+             frames, before subsampling)
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          Return a tuple containing 3 tensors:
+            - CTC output for ctc decoding. Its shape is (N, S, C)
+            - Encoder output with shape (S, N, C). It can be used as key and
+              value for the decoder.
+            - Encoder output padding mask. It can be used as
+              memory_key_padding_mask for the decoder. Its shape is (N, S).
+              It is None if `supervision` is None.
+        """
+
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision, warmup
+        )
+
+        x = self.ctc_output(encoder_memory)
+        return x, encoder_memory, memory_key_padding_mask
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        supervisions: Optional[Supervisions] = None,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Run the transformer encoder.
+
+        Args:
+          x:
+            The model input. Its shape is (N, S, C).
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling.
+            It is read directly from the batch, without any sorting. It is used
+            to compute the encoder padding mask, which is used as memory key
+            padding mask for the decoder.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          Return a tuple with two tensors:
+            - The encoder output, with shape (S, N, C)
+            - encoder padding mask, with shape (N, S).
+              The mask is None if `supervisions` is None.
+              It is used as memory key padding mask in the decoder.
+        """
+        x = self.encoder_embed(x)
+        x = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, S, C) -> (S, N, C)
+        mask = encoder_padding_mask(x.size(0), supervisions)
+        mask = mask.to(x.device) if mask is not None else None
+        x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup)  # (S, N, C)
+
+        return x, mask
+
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          x:
+            the output tensor from the transformer encoder;
+            its shape is (S, N, C)
+
+        Returns:
+          Return a tensor that can be used for CTC decoding.
+          Its shape is (N, S, C)
+        """
+        x = self.encoder_output_layer(x)
+        x = x.permute(1, 0, 2)  # (S, N, C) -> (N, S, C)
+        x = nn.functional.log_softmax(x, dim=-1)  # (N, S, C)
+        return x
+
+    @torch.jit.export
+    def decoder_forward(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder of shape (S, N, C)
+          memory_key_padding_mask:
+            The padding mask from the encoder of shape (N, S).
+          token_ids:
+            A list of lists of token IDs. Each sublist contains the IDs for
+            one utterance.
+            The IDs can be either phone IDs or word piece IDs.
+          sos_id:
+            sos token id
+          eos_id:
+            eos token id
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          A scalar, the **sum** of label smoothing loss over utterances
+          in the batch without any normalization.
+        """
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device)
+        ys_out_pad = ys_out_pad.to(device)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            warmup=warmup,
+        )  # (T, N, C)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+
+        decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
+
+        return decoder_loss
+
+    @torch.jit.export
+    def decoder_nll(
+        self,
+        memory: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+        token_ids: List[torch.Tensor],
+        sos_id: int,
+        eos_id: int,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            It's the output of the encoder of shape (S, N, C).
+          memory_key_padding_mask:
+            The padding mask from the encoder of shape (N, S).
+          token_ids:
+            A list of token-ID lists (e.g., word piece IDs), one per utterance.
+            1-D int tensors are also accepted when called from C++/torchscript.
+          sos_id:
+            The token ID for SOS.
+          eos_id:
+            The token ID for EOS.
+          warmup:
+            a floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          A 2-D tensor of shape (len(token_ids), max_token_length)
+          representing the cross entropy loss (i.e., negative log-likelihood).
+        """
+        # The common part between this function and decoder_forward could be
+        # extracted as a separate function.
+        if isinstance(token_ids[0], torch.Tensor):
+            # This branch is executed by torchscript in C++.
+            # See https://github.com/k2-fsa/k2/pull/870
+            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+            token_ids = [tolist(t) for t in token_ids]
+
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
+
+        device = memory.device
+        ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
+        ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
+
+        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
+
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
+
+        tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
+        tgt = self.decoder_pos(tgt)
+        tgt = tgt.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        pred_pad = self.decoder(
+            tgt=tgt,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            warmup=warmup,
+        )  # (T, N, C)
+        pred_pad = pred_pad.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        pred_pad = self.decoder_output_layer(pred_pad)  # (N, T, C)
+        # nll: negative log-likelihood
+        nll = torch.nn.functional.cross_entropy(
+            pred_pad.view(-1, self.decoder_num_class),
+            ys_out_pad.view(-1),
+            ignore_index=-1,
+            reduction="none",
+        )
+
+        nll = nll.view(pred_pad.shape[0], -1)
+
+        return nll
+
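+    # A small worked illustration of the SOS/EOS handling in decoder_forward()
+    # and decoder_nll(), assuming sos_id == eos_id == 1 as in the current setup:
+    #   token_ids  = [[6, 7, 8], [9]]
+    #   ys_in_pad  -> [[1, 6, 7, 8], [1, 9, 1, 1]]    (padded with eos_id)
+    #   ys_out_pad -> [[6, 7, 8, 1], [9, 1, -1, -1]]  (padded with -1, ignored)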
+
+class TransformerEncoderLayer(nn.Module):
+    """
+    Modified from torch.nn.TransformerEncoderLayer.
+
+    Example:
+        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = encoder_layer(src)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+    ) -> None:
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0 - bypass_scale) and bypass_scale, respectively (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+        """
+
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = MultiheadAttention(d_model, nhead)
+        # Implementation of Feedforward model
+
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        src_mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+          src:
+            the sequence to the encoder layer of shape (S, N, C) (required).
+          src_mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional)
+          warmup:
+            controls selective bypass of layers; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        src_orig = src
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
+        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+        # completely bypass it.
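+        # E.g. with bypass_scale=0.1 and warmup=0.5, warmup_scale is 0.6, so a
+        # kept layer contributes 0.6 * output + 0.4 * input, while a randomly
+        # bypassed layer contributes only 0.1 * output + 0.9 * input.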
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
+
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+        src = src + self.dropout(src_att)
+
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        if alpha != 1.0:
+            src = alpha * src + (1.0 - alpha) * src_orig
+
+        return src
+
+
+class TransformerDecoderLayer(nn.Module):
+    """Modified from torch.nn.TransformerDecoderLayer.
+
+    Example:
+        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> out = decoder_layer(tgt, memory)
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        bypass_scale: float = 0.1,
+        layer_dropout: float = 0.075,
+    ) -> None:
+
+        """
+        Args:
+          d_model:
+            the number of expected features in the input (required).
+          nhead:
+            the number of heads in the multiheadattention models (required).
+          dim_feedforward:
+            the dimension of the feedforward network model (default=2048).
+          dropout:
+            the dropout value (default=0.1).
+          bypass_scale:
+            a scale on the layer's output, used in bypass (resnet-type) skip-connection;
+            when the layer is bypassed, the final output will be a
+            weighted sum of the layer's input and layer's output with weights
+            (1.0 - bypass_scale) and bypass_scale, respectively (default=0.1).
+          layer_dropout:
+            the probability to bypass the layer (default=0.075).
+        """
+
+        super().__init__()
+
+        if bypass_scale < 0.0 or bypass_scale > 1.0:
+            raise ValueError("bypass_scale should be between 0.0 and 1.0")
+
+        if layer_dropout < 0.0 or layer_dropout > 1.0:
+            raise ValueError("layer_dropout should be between 0.0 and 1.0")
+
+        self.bypass_scale = bypass_scale
+        self.layer_dropout = layer_dropout
+
+        self.self_attn = MultiheadAttention(d_model, nhead)
+        self.src_attn = MultiheadAttention(d_model, nhead)
+
+        # Implementation of Feedforward model
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        # try to ensure the output is close to zero-mean (or at least, zero-median).
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the inputs (and mask) through the decoder layer.
+
+        Args:
+          tgt:
+            the sequence to the decoder layer of shape (T, N, C) (required).
+          memory:
+            the sequence from the last layer of the encoder of shape (S, N, C) (required).
+          tgt_mask:
+            the mask for the tgt sequence of shape (T, T) (optional).
+          memory_mask:
+            the mask for the memory sequence of shape (T, S) (optional).
+          tgt_key_padding_mask:
+            the mask for the tgt keys per batch of shape (N, T) (optional).
+          memory_key_padding_mask:
+            the mask for the memory keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layers; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (T, N, C), where
+          S is the source sequence length,
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        tgt_orig = tgt
+
+        warmup_scale = min(self.bypass_scale + warmup, 1.0)
+        # alpha = 1.0 means fully use this decoder layer, 0.0 would mean
+        # completely bypass it.
+        if self.training:
+            alpha = (
+                warmup_scale
+                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
+                else self.bypass_scale
+            )
+        else:
+            alpha = 1.0
+
+        tgt_att = self.self_attn(
+            tgt,
+            tgt,
+            tgt,
+            attn_mask=tgt_mask,
+            key_padding_mask=tgt_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout(tgt_att)
+
+        src_att = self.src_attn(
+            tgt,
+            memory,
+            memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout(src_att)
+
+        tgt = tgt + self.dropout(self.feed_forward(tgt))
+
+        tgt = self.norm_final(self.balancer(tgt))
+
+        if alpha != 1.0:
+            tgt = alpha * tgt + (1.0 - alpha) * tgt_orig
+
+        return tgt
+
+
+class TransformerEncoder(nn.Module):
+    """TransformerEncoder is a stack of N encoder layers
+
+    Examples:
+        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = transformer_encoder(src)
+    """
+
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+        """
+        Args:
+          encoder_layer:
+            an instance of the TransformerEncoderLayer() class (required).
+          num_layers:
+            the number of sub-encoder-layers in the encoder (required).
+          aux_layers:
+            list of indexes of sub-encoder-layers outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the input through the encoder layers in turn.
+
+        Args:
+          src:
+            the input to the encoder of shape (S, N, C) (required).
+          mask:
+            the mask for the src sequence of shape (S, S) (optional).
+          src_key_padding_mask:
+            the mask for the src keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layer; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (S, N, C), where
+          S is the source sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = src
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    """TransformerDecoder is a stack of N decoder layers
+
+    Examples:
+        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
+        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
+        >>> memory = torch.rand(10, 32, 512)
+        >>> tgt = torch.rand(20, 32, 512)
+        >>> out = transformer_decoder(tgt, memory)
+    """
+
+    def __init__(
+        self,
+        decoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
+        """
+        Args:
+          decoder_layer:
+            an instance of the TransformerDecoderLayer() class (required).
+          num_layers:
+            the number of decoder layers in the decoder (required).
+          aux_layers:
+            list of indexes of decoder layer outputs to be combined (required).
+        """
+
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(decoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert len(set(aux_layers)) == len(aux_layers)
+
+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = aux_layers + [num_layers - 1]
+
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        tgt_key_padding_mask: Optional[torch.Tensor] = None,
+        memory_key_padding_mask: Optional[torch.Tensor] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """Pass the input (and mask) through the decoder layers in turn.
+
+        Args:
+          tgt:
+            the sequence to the decoder of shape (T, N, C) (required).
+          memory:
+            the sequence from the last layer of the encoder of shape (S, N, C) (required).
+          tgt_mask:
+            the mask for the tgt sequence of shape (T, T) (optional).
+          memory_mask:
+            the mask for the memory sequence of shape (T, S) (optional).
+          tgt_key_padding_mask:
+            the mask for the tgt keys per batch of shape (N, T)  (optional).
+          memory_key_padding_mask:
+            the mask for the memory keys per batch of shape (N, S) (optional).
+          warmup:
+            controls selective bypass of layer; if < 1.0, we will
+            bypass the layer more frequently (default=1.0).
+
+        Returns:
+          Output tensor of the shape (T, N, C), where
+          S is the source sequence length,
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        output = tgt
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    """This class implements the positional encoding
+    proposed in the following paper:
+
+    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+        PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
+        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+    Note:
+
+      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+                               = exp(-1 * 2i / d_model * log(10000))
+                               = exp(2i * -(log(10000) / d_model))
+    """
+
+    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+        """
+        Args:
+          d_model: Embedding dimension.
+          dropout: Dropout probability to be applied to the output of this module.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = nn.Dropout(p=dropout)
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """Extend the time t in the positional encoding if required.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (1, T, d_model). Otherwise, nothing is done.
+
+        Args:
+          x:
+            It is a tensor of shape (N, T, C).
+            T is the target sequence length,
+            N is the batch size,
+            C is the feature number.
+        """
+        if self.pe is not None:
+            if self.pe.size(1) >= x.size(1):
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
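+        # The table is cached; later calls whose T is not larger reuse it via
+        # the early return above instead of recomputing.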
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Add positional encoding.
+
+        Args:
+          x: Input of shape is (N, T, C)
+
+        Returns:
+          A tensor of the same shape (N, T, C),
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        self.extend_pe(x)
+        x = x + self.pe[:, : x.size(1), :]
+        return self.dropout(x)
+
+
+def encoder_padding_mask(
+    max_len: int, supervisions: Optional[Supervisions] = None
+) -> Optional[torch.Tensor]:
+    """Make mask tensor containing indexes of padded part.
+
+    TODO:
+      This function **assumes** that the model uses
+      a subsampling factor of 4. We should remove that
+      assumption later.
+
+    Args:
+      max_len:
+        Maximum length of input features.
+        CAUTION: It is the length after subsampling.
+      supervisions:
+        Supervision in lhotse format.
+        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+        (CAUTION: It contains length information, i.e., start and number of
+         frames, before subsampling)
+
+    Returns:
+      Mask tensor of dimension (batch_size, input_length),
+      True denotes the masked indices.
+    """
+    if supervisions is None:
+        return None
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"],
+            supervisions["num_frames"],
+        ),
+        1,
+    ).to(torch.int32)
+
+    lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
+    for idx in range(supervision_segments.size(0)):
+        # Note: TorchScript doesn't allow to unpack tensors as tuples
+        sequence_idx = supervision_segments[idx, 0].item()
+        start_frame = supervision_segments[idx, 1].item()
+        num_frames = supervision_segments[idx, 2].item()
+        lengths[sequence_idx] = start_frame + num_frames
+
+    lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
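+    # This mirrors the subsampling factor of 4 assumed above; e.g. 100 input
+    # frames become ((100 - 1) // 2 - 1) // 2 = 24 frames after subsampling.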
+    bs = int(len(lengths))
+    seq_range = torch.arange(0, max_len, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+    # Note: TorchScript doesn't implement Tensor.new()
+    seq_length_expand = torch.tensor(
+        lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
+    ).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    return mask
+
+
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
+    """Generate a length mask for input.
+
+    Masked positions are filled with True;
+    unmasked positions are filled with False.
+
+    Args:
+      ys_pad:
+        padded tensor of dimension (batch_size, input_length).
+      ignore_id:
+        the ignored number (the padding number) in ys_pad
+
+    Returns:
+        A bool tensor of the same shape as the input tensor.
+    """
+    ys_mask = ys_pad == ignore_id
+    return ys_mask
+
+
+def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
+    """Generate a square mask for the sequence. The masked positions are
+    filled with float('-inf'). Unmasked positions are filled with float(0.0).
+    The mask can be used for masked self-attention.
+
+    For instance, if sz is 3, it returns::
+
+        tensor([[0., -inf, -inf],
+                [0., 0., -inf],
+                [0., 0., 0]])
+
+    Args:
+      sz: mask size
+
+    Returns:
+      A square mask tensor of dimension (sz, sz)
+    """
+    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+    mask = (
+        mask.float()
+        .masked_fill(mask == 0, float("-inf"))
+        .masked_fill(mask == 1, float(0.0))
+    )
+    return mask
+
+
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+    """Prepend sos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      sos_id:
+        The ID of the SOS token.
+
+    Return:
+      Return a new list-of-list, where each sublist starts
+      with SOS ID.
+    """
+    return [[sos_id] + utt for utt in token_ids]
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+    """Append eos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-lists of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      eos_id:
+        The ID of the EOS token.
+
+    Return:
+      Return a new list-of-lists, where each sublist ends
+      with EOS ID.
+    """
+    return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+    """Used by jit"""
+    return torch.jit.annotate(List[int], t.tolist())
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 327962a79..733ebf235 100755
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -83,9 +83,7 @@ def compute_fbank_tedlium():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cur_num_jobs = num_jobs if ex is None else 80
             cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -104,9 +102,7 @@ def compute_fbank_tedlium():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index 49544ccb3..19ba8d24b 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -4,16 +4,18 @@
 """
 Convert a transcript based on words to a list of BPE ids.
 
-For example, if we use 2 as the encoding id of <unk>:
+For example, if we use 2 as the encoding id of <unk>.
+Note: a space token is inserted before each <unk>.
 
texts = ['this is a <unk> day']
-spm_ids = [[38, 33, 6, 2, 316]]
+spm_ids = [[38, 33, 6, 15, 2, 316]]
 
texts = ['<unk> this is a sunny day']
-spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316]]
+spm_ids = [[15, 2, 38, 33, 6, 118, 11, 11, 21, 316]]
 
texts = ['<unk>']
-spm_ids = [[2]]
+spm_ids = [[15, 2]]
+
 """
 
 import argparse
@@ -25,9 +27,7 @@ import sentencepiece as spm
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--texts", type=List[str], help="The input transcripts list."
-    )
+    parser.add_argument("--texts", type=List[str], help="The input transcripts list.")
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -40,29 +40,27 @@ def get_args():
 
 def convert_texts_into_ids(
     texts: List[str],
-    unk_id: int,
     sp: spm.SentencePieceProcessor,
 ) -> List[List[int]]:
     """
     Args:
       texts:
         A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
-      unk_id:
-        A number id for the token ''.
+      sp:
+        A sentencepiece BPE model.
     Returns:
       Return an integer list of bpe ids.
     """
     y = []
     for text in texts:
-        y_ids = []
         if "" in text:
-            text_segments = text.split("")
-            id_segments = sp.encode(text_segments, out_type=int)
+            id_segments = sp.encode(text.split(""), out_type=int)
+
+            y_ids = []
             for i in range(len(id_segments)):
-                if i != len(id_segments) - 1:
-                    y_ids.extend(id_segments[i] + [unk_id])
-                else:
-                    y_ids.extend(id_segments[i])
+                y_ids += id_segments[i]
+                if i < len(id_segments) - 1:
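+                    # Between segments, re-insert the space piece followed by
+                    # <unk> so that each <unk> keeps its own word boundary
+                    # (see the examples in the module docstring).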
+                    y_ids += [sp.piece_to_id("▁"), sp.unk_id()]
         else:
             y_ids = sp.encode(text, out_type=int)
         y.append(y_ids)
@@ -72,19 +70,13 @@ def convert_texts_into_ids(
 
 def main():
     args = get_args()
-    texts = args.texts
-    bpe_model = args.bpe_model
 
     sp = spm.SentencePieceProcessor()
-    sp.load(bpe_model)
-    unk_id = sp.piece_to_id("")
+    sp.load(args.bpe_model)
 
-    y = convert_texts_into_ids(
-        texts=texts,
-        unk_id=unk_id,
-        sp=sp,
-    )
-    logging.info(f"The input texts: {texts}")
+    y = convert_texts_into_ids(texts=args.texts, sp=sp)
+
+    logging.info(f"The input texts: {args.texts}")
     logging.info(f"The encoding ids: {y}")
 
 
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
deleted file mode 100755
index 35dd332e8..000000000
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ /dev/null
@@ -1,97 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2022  Xiaomi Corp.        (authors: Mingshuang Luo)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This script takes as input supervisions json dir "data/manifests"
-consisting of supervisions_train.json and does the following:
-
-1. Generate lexicon_words.txt.
-
-"""
-import lhotse
-import argparse
-import logging
-from pathlib import Path
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--manifests-dir",
-        type=str,
-        help="""Input directory.
-        """,
-    )
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        help="""Output directory.
-        """,
-    )
-
-    return parser.parse_args()
-
-
-def prepare_lexicon(manifests_dir: str, lang_dir: str):
-    """
-    Args:
-      manifests_dir:
-        The manifests directory, e.g., data/manifests.
-      lang_dir:
-        The language directory, e.g., data/lang_phone.
-
-    Return:
-      The lexicon_words.txt file.
-    """
-    words = set()
-
-    lexicon = Path(lang_dir) / "lexicon_words.txt"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
-    for s in sups:
-        # list the words units and filter the empty item
-        words_list = list(filter(None, s.text.split()))
-
-        for word in words_list:
-            if word not in words and word != "":
-                words.add(word)
-
-    with open(lexicon, "w") as f:
-        for word in sorted(words):
-            f.write(word + "  " + word)
-            f.write("\n")
-
-
-def main():
-    args = get_args()
-    manifests_dir = Path(args.manifests_dir)
-    lang_dir = Path(args.lang_dir)
-
-    logging.info("Generating lexicon_words.txt")
-    prepare_lexicon(manifests_dir, lang_dir)
-
-
-if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    main()
diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py
index 1039ac5bb..d4ccdd1e3 100755
--- a/egs/tedlium3/ASR/local/prepare_transcripts.py
+++ b/egs/tedlium3/ASR/local/prepare_transcripts.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright    2021  Xiaomi Corp.        (authors: Mingshuang Luo)
+# Copyright    2021  Xiaomi Corp.        (author: Mingshuang Luo)
+# Copyright    2022  Behavox LLC.        (author: Daniil Kulko)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -17,75 +18,71 @@
 
 
 """
-This script takes as input supervisions json dir "data/manifests"
-consisting of supervisions_train.json and does the following:
-
-1. Generate train.text.
+This script takes an input text file and removes all words
+that include any character outside the English alphabet.
 
 """
-import lhotse
 import argparse
 import logging
+import re
 from pathlib import Path
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--manifests-dir",
+        "--input-text-path",
         type=str,
-        help="""Input directory.
-        """,
+        help="Input text file path.",
     )
     parser.add_argument(
-        "--lang-dir",
+        "--output-text-path",
         type=str,
-        help="""Output directory.
-        """,
+        help="Output text file path.",
     )
 
     return parser.parse_args()
 
 
-def prepare_transcripts(manifests_dir: str, lang_dir: str):
+def prepare_transcripts(input_text_path: Path, output_text_path: Path) -> None:
     """
     Args:
-      manifests_dir:
-        The manifests directory, e.g., data/manifests.
-      lang_dir:
-        The language directory, e.g., data/lang_phone.
+      input_text_path:
+        The input data text file path, e.g., data/lang/train_orig.txt.
+      output_text_path:
+        The output data text file path, e.g., data/lang/train.txt.
 
     Return:
-      The train.text in lang_dir.
+      Saved text file in output_text_path.
     """
-    texts = []
 
-    train_text = Path(lang_dir) / "train.text"
-    sups = lhotse.load_manifest(
-        f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz"
-    )
-    for s in sups:
-        texts.append(s.text)
+    foreign_chr_check = re.compile(r"[^a-z']")
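+    # Any word containing a character other than a-z or an apostrophe
+    # (digits, hyphens, uppercase or accented letters) is dropped below.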
 
-    with open(train_text, "w") as f:
-        for text in texts:
-            f.write(text)
-            f.write("\n")
+    logging.info(f"Loading {input_text_path.name}")
+    with open(input_text_path, "r", encoding="utf8") as f:
+        texts = {t.rstrip("\n") for t in f}
+
+    texts = {
+        " ".join([w for w in t.split() if foreign_chr_check.search(w) is None])
+        for t in texts
+    }
+
+    with open(output_text_path, "w+", encoding="utf8") as f:
+        for t in texts:
+            f.write(f"{t}\n")
 
 
-def main():
+def main() -> None:
     args = get_args()
-    manifests_dir = Path(args.manifests_dir)
-    lang_dir = Path(args.lang_dir)
+    input_text_path = Path(args.input_text_path)
+    output_text_path = Path(args.output_text_path)
 
-    logging.info("Generating train.text")
-    prepare_transcripts(manifests_dir, lang_dir)
+    logging.info(f"Generating {output_text_path.name}")
+    prepare_transcripts(input_text_path, output_text_path)
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/tedlium3/ASR/local/prepare_words.py b/egs/tedlium3/ASR/local/prepare_words.py
new file mode 100755
index 000000000..a37d0f08f
--- /dev/null
+++ b/egs/tedlium3/ASR/local/prepare_words.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+# Copyright    2022  Behavox LLC.        (authors: Daniil Kulko)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input the lang directory "data/lang"
+containing words_orig.txt and does the following:
+
+1. Generate words.txt.
+
+"""
+import argparse
+import logging
+import re
+from pathlib import Path
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="Output directory.",
+    )
+
+    return parser.parse_args()
+
+
+def prepare_words(lang_dir: str) -> None:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang.
+
+    Return:
+      The words.txt file.
+    """
+
+    words_orig_path = Path(lang_dir) / "words_orig.txt"
+    words_path = Path(lang_dir) / "words.txt"
+
+    foreign_chr_check = re.compile(r"[^a-z']")
+
+    logging.info(f"Loading {words_orig_path.name}")
+    with open(words_orig_path, "r", encoding="utf8") as f:
+        words = {w for w_compl in f for w in w_compl.strip("-\n").split("_")}
+    words = {w for w in words if foreign_chr_check.search(w) is None and w != ""}
+    words.add("")
+    words = ["", "!SIL"] + sorted(words) + ["#0", "", ""]
+
+    with open(words_path, "w+", encoding="utf8") as f:
+        for idx, word in enumerate(words):
+            f.write(f"{word} {idx}\n")
+
+
+def main() -> None:
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    logging.info("Generating words.txt")
+    prepare_words(lang_dir)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh
index ccb307a52..3d90436ff 100755
--- a/egs/tedlium3/ASR/prepare.sh
+++ b/egs/tedlium3/ASR/prepare.sh
@@ -1,8 +1,10 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
-nj=15
 stage=0
 stop_stage=100
 
@@ -60,6 +62,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     mv $dl_dir/TEDLIUM_release-3 $dl_dir/tedlium3
   fi
 
+  # Download big and small 4-gram language models
+  if [ ! -d $dl_dir/lm ]; then
+    wget --continue http://kaldi-asr.org/models/5/4gram_small.arpa.gz -P $dl_dir/lm
+    wget --continue http://kaldi-asr.org/models/5/4gram_big.arpa.gz -P $dl_dir/lm
+    gzip -d $dl_dir/lm/4gram_small.arpa.gz $dl_dir/lm/4gram_big.arpa.gz
+  fi
+
   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
@@ -97,7 +106,14 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 
   if [ ! -e data/fbank/.tedlium3.done ]; then
     mkdir -p data/fbank
+
     python3 ./local/compute_fbank_tedlium.py
+
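+    # Shuffle the manifest on disk so that the lazily loaded training cuts
+    # are not grouped by recording.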
+    gunzip -c data/fbank/tedlium_cuts_train.jsonl.gz | shuf | \
+    gzip -c > data/fbank/tedlium_cuts_train-shuf.jsonl.gz
+    mv data/fbank/tedlium_cuts_train-shuf.jsonl.gz \
+       data/fbank/tedlium_cuts_train.jsonl.gz
+
     touch data/fbank/.tedlium3.done
   fi
 fi
@@ -112,28 +128,24 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Prepare phone based lang"
-  lang_dir=data/lang_phone
+  log "Stage 5: Prepare BPE train data and set of words"
+  lang_dir=data/lang
   mkdir -p $lang_dir
 
-  if [ ! -f $lang_dir/train.text ]; then
+  if [ ! -f $lang_dir/train.txt ]; then
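+    # The TED-LIUM LM texts end each sentence with " </s>"; strip that marker
+    # before further cleaning.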
+    gunzip -c $dl_dir/tedlium3/LM/*.en.gz | sed 's: <\/s>::g' > $lang_dir/train_orig.txt
+
     ./local/prepare_transcripts.py \
-      --lang-dir $lang_dir \
-      --manifests-dir data/manifests
+      --input-text-path $lang_dir/train_orig.txt \
+      --output-text-path $lang_dir/train.txt
   fi
 
-  if [ ! -f $lang_dir/lexicon_words.txt ]; then
-    ./local/prepare_lexicon.py \
-      --lang-dir $lang_dir \
-      --manifests-dir data/manifests
-  fi
+  if [ ! -f $lang_dir/words.txt ]; then
 
-  (echo '!SIL SIL'; echo ' '; ) |
-    cat - $lang_dir/lexicon_words.txt |
-    sort | uniq > $lang_dir/lexicon.txt
+    awk '{print $1}' $dl_dir/tedlium3/TEDLIUM.152k.dic |
+    sed 's:([0-9])::g' | sort | uniq > $lang_dir/words_orig.txt
 
-  if [ ! -f $lang_dir/L_disambig.pt ]; then
-    ./local/prepare_lang.py --lang-dir $lang_dir
+    ./local/prepare_words.py --lang-dir $lang_dir
   fi
 fi
 
@@ -145,25 +157,56 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     mkdir -p $lang_dir
     # We reuse words.txt from phone based lexicon
     # so that the two can share G.pt later.
-    cp data/lang_phone/words.txt $lang_dir
-
-    if [ ! -f $lang_dir/transcript_words.txt ]; then
-      log "Generate data for BPE training"
-      cat data/lang_phone/train.text |
-      cut -d " " -f 2- > $lang_dir/transcript_words.txt
-      # remove the  for transcript_words.txt
-      sed -i 's/ //g' $lang_dir/transcript_words.txt
-      sed -i 's/ //g' $lang_dir/transcript_words.txt
-      sed -i 's///g' $lang_dir/transcript_words.txt
-    fi
+    cp data/lang/words.txt $lang_dir
 
     ./local/train_bpe_model.py \
       --lang-dir $lang_dir \
       --vocab-size $vocab_size \
-      --transcript $lang_dir/transcript_words.txt
+      --transcript data/lang/train.txt
 
     if [ ! -f $lang_dir/L_disambig.pt ]; then
-      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir --oov "<unk>"
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare G"
+  # We assume you have install kaldilm, if not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_4_gram_small.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      --max-arpa-warnings=-1 \
+      $dl_dir/lm/4gram_small.arpa > data/lm/G_4_gram_small.fst.txt
+  fi
+
+  if [ ! -f data/lm/G_4_gram_big.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      --max-arpa-warnings=-1 \
+      $dl_dir/lm/4gram_big.arpa > data/lm/G_4_gram_big.fst.txt
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Compile HLG"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/HLG.pt ]; then
+      ./local/compile_hlg.py \
+        --lang-dir $lang_dir \
+        --lm G_4_gram_small
     fi
   done
 fi
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 2b294e601..abba9d403 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -172,8 +172,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -231,9 +230,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -248,10 +245,7 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -374,9 +368,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -387,18 +379,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -408,10 +396,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
index a1c3bcea3..aa22f82ec 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py
@@ -106,8 +106,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -179,9 +178,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 8480ac029..8a89c3578 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -203,10 +202,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -271,9 +269,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -298,10 +294,7 @@ def main():
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -353,9 +346,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 8d5cdf683..170f37767 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -157,8 +156,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -556,9 +554,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -678,9 +674,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 51de46ae8..c647392f0 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -17,10 +17,10 @@
 
 
 import argparse
-import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional
 
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
@@ -28,7 +28,6 @@ from lhotse.dataset import (
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
@@ -140,7 +139,6 @@ class TedLiumAsrDataModule:
             "field: batch['supervisions']['cut'] with the cuts that "
             "were used to construct it.",
         )
-
         group.add_argument(
             "--num-workers",
             type=int,
@@ -148,14 +146,12 @@ class TedLiumAsrDataModule:
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
-
         group.add_argument(
             "--enable-spec-aug",
             type=str2bool,
             default=True,
             help="When enabled, use SpecAugment for training dataset.",
         )
-
         group.add_argument(
             "--spec-aug-time-warp-factor",
             type=int,
@@ -165,27 +161,51 @@ class TedLiumAsrDataModule:
             "Larger values mean more warping. "
             "A value less than 1 means to disable time warp.",
         )
-
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
             help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            "with training dataset.",
         )
 
-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=10,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                    max_frames_mask_fraction=0.15,
+                    p=0.9,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
         logging.info("About to get Musan cuts")
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -204,42 +224,7 @@ class TedLiumAsrDataModule:
                 )
             ] + transforms
 
-        input_transforms = []
-        if self.args.enable_spec_aug:
-            logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
-            # Set the value of num_frame_masks according to Lhotse's version.
-            # In different Lhotse's versions, the default of num_frame_masks is
-            # different.
-            num_frame_masks = 10
-            num_frame_masks_parameter = inspect.signature(
-                SpecAugment.__init__
-            ).parameters["num_frame_masks"]
-            if num_frame_masks_parameter.default == 1:
-                num_frame_masks = 2
-            logging.info(f"Num frame mask: {num_frame_masks}")
-            input_transforms.append(
-                SpecAugment(
-                    time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=num_frame_masks,
-                    features_mask_size=27,
-                    num_feature_masks=2,
-                    frames_mask_size=100,
-                    max_frames_mask_fraction=0.15,
-                    p=0.9,
-                )
-            )
-        else:
-            logging.info("Disable SpecAugment")
-
         logging.info("About to create train dataset")
-        train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
-            return_cuts=self.args.return_cuts,
-        )
         if self.args.on_the_fly_feats:
             # NOTE: the PerturbSpeed transform should be added only if we
             # remove it from data prep stage.
@@ -253,9 +238,13 @@ class TedLiumAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -276,6 +265,11 @@ class TedLiumAsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
             )
+
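+        # Restoring the sampler state lets a resumed training run continue from
+        # where the previous run left off in the training data, rather than
+        # re-iterating the epoch from the start.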
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
         logging.info("About to create train dataloader")
         train_dl = DataLoader(
             train,
@@ -288,6 +282,7 @@ class TedLiumAsrDataModule:
         return train_dl
 
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -300,9 +295,7 @@ class TedLiumAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -310,11 +303,13 @@ class TedLiumAsrDataModule:
                 cut_transforms=transforms,
                 return_cuts=self.args.return_cuts,
             )
+
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
             validate,
@@ -326,25 +321,32 @@ class TedLiumAsrDataModule:
 
         return valid_dl
 
-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
+
         logging.debug("About to create test dataset")
-        test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
-            return_cuts=self.args.return_cuts,
-        )
-        sampler = DynamicBucketingSampler(
-            cuts,
+        if self.args.on_the_fly_feats:
+            test = K2SpeechRecognitionDataset(
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            test = K2SpeechRecognitionDataset(
+                return_cuts=self.args.return_cuts,
+            )
+
+        test_sampler = DynamicBucketingSampler(
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
             test,
             batch_size=None,
-            sampler=sampler,
+            sampler=test_sampler,
             num_workers=self.args.num_workers,
+            persistent_workers=False,
         )
         return test_dl
 
@@ -358,13 +360,9 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz")
 
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
index 77caf6460..1f99edaf3 100644
--- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py
+++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py
@@ -87,9 +87,9 @@ def greedy_search(
         y = logits.argmax().item()
         if y != blank_id and y != unk_id:
             hyp.append(y)
-            decoder_input = torch.tensor(
-                [hyp[-context_size:]], device=device
-            ).reshape(1, context_size)
+            decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+                1, context_size
+            )
 
             decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -148,9 +148,7 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
         else:
             self._data[key] = hyp
 
@@ -166,9 +164,7 @@ class HypothesisList(object):
           Return the hypothesis that has the largest `log_prob`.
         """
         if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
         else:
             return max(self._data.values(), key=lambda hyp: hyp.log_prob)
 
@@ -344,9 +340,9 @@ def modified_beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
@@ -383,9 +379,7 @@ def modified_beam_search(
         decoder_out = model.decoder(decoder_input, need_pad=False)
         # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
 
-        current_encoder_out = current_encoder_out.expand(
-            decoder_out.size(0), 1, -1
-        )
+        current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1)
 
         logits = model.joiner(
             current_encoder_out,
@@ -454,9 +448,9 @@ def beam_search(
 
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
+    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+        1, context_size
+    )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
 
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index d3e9e55e7..fb0e3116b 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -130,8 +130,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -250,9 +249,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
     batch_size = encoder_out.size(0)
 
@@ -275,9 +272,7 @@ def decode_one_batch(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
         hyps.append(sp.decode(hyp).split())
 
     if params.decoding_method == "greedy_search":
@@ -348,9 +343,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -361,18 +354,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -382,10 +371,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py
index f0c6f32b6..f9a3814c6 100644
--- a/egs/tedlium3/ASR/transducer_stateless/decoder.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py
@@ -90,9 +90,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py
index c32b1d002..48dcdc736 100644
--- a/egs/tedlium3/ASR/transducer_stateless/export.py
+++ b/egs/tedlium3/ASR/transducer_stateless/export.py
@@ -110,8 +110,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -247,9 +246,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index c0e3bb844..81afd6a4e 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -127,8 +127,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -222,10 +221,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -285,9 +283,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -335,9 +331,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 09cbf4a00..6fed32e81 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -133,8 +133,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -525,9 +524,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -647,9 +644,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
diff --git a/egs/timit/ASR/README.md b/egs/timit/ASR/README.md
index f10bfccfd..d493fc479 100644
--- a/egs/timit/ASR/README.md
+++ b/egs/timit/ASR/README.md
@@ -1,3 +1,3 @@
 
-Please refer to 
+Please refer to 
 for how to run models in this recipe.
diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md
index b78c16b88..d8ceb82b6 100644
--- a/egs/timit/ASR/RESULTS.md
+++ b/egs/timit/ASR/RESULTS.md
@@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \
                                --avg 17 \
                                --max-duration 20 \
                                --lang-dir data/lang_phone
-```
\ No newline at end of file
+```
diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py
index 58cab4cf2..c8562f4fb 100644
--- a/egs/timit/ASR/local/compile_hlg.py
+++ b/egs/timit/ASR/local/compile_hlg.py
@@ -100,7 +100,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 
     logging.info("Removing disambiguation symbols on LG")
 
-    LG.labels[LG.labels >= first_token_disambig_id] = 0
+    # LG.labels[LG.labels >= first_token_disambig_id] = 0
+    # see https://github.com/k2-fsa/k2/pull/1140
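+    # Work around the change above: take the labels tensor, zero out the
+    # disambiguation symbols on it, and assign it back through LG.labels so that
+    # the FSA is updated consistently.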
+    labels = LG.labels
+    labels[labels >= first_token_disambig_id] = 0
+    LG.labels = labels
 
     LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
 
@@ -146,9 +150,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py
index f25786a0c..ecdf10ba9 100644
--- a/egs/timit/ASR/local/compute_fbank_timit.py
+++ b/egs/timit/ASR/local/compute_fbank_timit.py
@@ -85,9 +85,7 @@ def compute_fbank_timit():
             )
             if partition == "TRAIN":
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -101,9 +99,7 @@ def compute_fbank_timit():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py
index 04023a9ab..0cf0f0deb 100644
--- a/egs/timit/ASR/local/prepare_lexicon.py
+++ b/egs/timit/ASR/local/prepare_lexicon.py
@@ -62,9 +62,7 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
 
     phones = set()
 
-    supervisions_train = (
-        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
-    )
+    supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
@@ -97,9 +95,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh
index ae1b96a68..148a9f51b 100644
--- a/egs/timit/ASR/prepare.sh
+++ b/egs/timit/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 num_phones=39
@@ -20,9 +23,9 @@ stop_stage=100
 #  - $dl_dir/lm
 #      This directory contains the language model(LM) downloaded from
 #      https://huggingface.co/luomingshuang/timit_lm, and the LM is based
-#	     on 39 phones. About how to get these LM files, you can know it 
+#	     on 39 phones. You can learn how to obtain these LM files
 #      from https://github.com/luomingshuang/Train_LM_with_kaldilm.
-#	
+#
 #	    - lm_3_gram.arpa
 #     - lm_4_gram.arpa
 #
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
index 4f2aa2340..4beeed18c 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py
@@ -336,9 +336,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -400,9 +398,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -462,9 +458,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -485,9 +479,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py
index 4d2199ace..9a594a969 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/model.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py
@@ -16,11 +16,11 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
-
 from torch import Tensor
-from typing import Optional
 
 
 class TdnnLiGRU(nn.Module):
@@ -261,9 +261,7 @@ class LiGRU(torch.nn.Module):
         h = []
         if hx is not None:
             if self.bidirectional:
-                hx = hx.reshape(
-                    self.num_layers, self.batch_size * 2, self.hidden_size
-                )
+                hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size)
         # Processing the different layers
         for i, ligru_lay in enumerate(self.rnn):
             if hx is not None:
@@ -445,9 +443,7 @@ class LiGRU_Layer(torch.nn.Module):
             if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                 self.drop_mask_cnt = 0
                 self.drop_masks = self.drop(
-                    torch.ones(
-                        self.N_drop_masks, self.hidden_size, device=w.device
-                    )
+                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
                 ).data
 
             # Sampling the mask
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 7da285944..3fdf3b855 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -144,10 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py
index 452c2a7cb..48b7feda0 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/train.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 1554e987f..51ca4cc6e 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -154,9 +154,7 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()
 
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.feature_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz")
 
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
@@ -178,9 +176,9 @@ class TimitAsrDataModule(DataModule):
         # In different Lhotse's versions, the default of num_frame_masks is
         # different.
         num_frame_masks = 10
-        num_frame_masks_parameter = inspect.signature(
-            SpecAugment.__init__
-        ).parameters["num_frame_masks"]
+        num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[
+            "num_frame_masks"
+        ]
         if num_frame_masks_parameter.default == 1:
             num_frame_masks = 2
         logging.info(f"Num frame mask: {num_frame_masks}")
@@ -212,9 +210,7 @@ class TimitAsrDataModule(DataModule):
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -263,9 +259,7 @@ class TimitAsrDataModule(DataModule):
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -299,20 +293,14 @@ class TimitAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                )
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                 if self.args.on_the_fly_feats
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(
-                cuts_test, max_duration=self.args.max_duration
-            )
+            sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
             logging.debug("About to create test dataloader")
-            test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
-            )
+            test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
             test_loaders.append(test_dl)
 
         if is_list:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
index 5e7300cf2..502a48def 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py
@@ -335,9 +335,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -399,9 +397,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -461,9 +457,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -483,9 +477,7 @@ def main():
         G=G,
     )
 
-    save_results(
-        params=params, test_set_name=test_set, results_dict=results_dict
-    )
+    save_results(params=params, test_set_name=test_set, results_dict=results_dict)
 
     logging.info("Done!")
 
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py
index 51edb97e2..e211ad80d 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py
@@ -74,10 +74,7 @@ class TdnnLstm(nn.Module):
             nn.BatchNorm1d(num_features=512, affine=False),
         )
         self.lstms = nn.ModuleList(
-            [
-                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
-                for _ in range(4)
-            ]
+            [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)]
         )
         self.lstm_bnorms = nn.ModuleList(
             [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 5f478da1c..98c746ce5 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,11 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -58,9 +54,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "--method",
@@ -144,10 +138,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -215,9 +208,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     features = features.permute(0, 2, 1)  # now features is (N, C, T)
 
     with torch.no_grad():
@@ -269,9 +260,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py
index 849256b98..be1ecffaa 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py
@@ -449,9 +449,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/wenetspeech/ASR/finetune.sh b/egs/wenetspeech/ASR/finetune.sh
new file mode 100755
index 000000000..8559780e9
--- /dev/null
+++ b/egs/wenetspeech/ASR/finetune.sh
@@ -0,0 +1,82 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# This is an example script for fine-tuning. Here, we fine-tune a model trained
+# on WenetSpeech on Aishell. The model used for fine-tuning is
+# pruned_transducer_stateless2 (Conformer). If you want to fine-tune a model
+# from another recipe, you can adapt ./pruned_transducer_stateless2/finetune.py
+# for that recipe. If you run into any problems, please open an issue at https://github.com/k2-fsa/icefall/issues.
+
+# We assume that you have already prepared the Aishell manifests & features under ./data.
+# If you haven't done that, please see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/prepare.sh.
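+#
+# Workflow: stage -1 downloads the pre-trained WenetSpeech model from Hugging Face,
+# stage 0 fine-tunes it on Aishell with ./pruned_transducer_stateless2/finetune.py,
+# and stage 1 decodes the fine-tuned model with greedy and modified beam search.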
+
+. shared/parse_options.sh || exit 1
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: Download Pre-trained model"
+
+  # clone from huggingface
+  git lfs install
+  git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2
+
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Start fine-tuning"
+
+  # The following learning rate schedule configuration should work well.
+  # You may also tune these parameters to adjust the learning rate schedule.
+  initial_lr=0.0001
+  lr_epochs=100
+  lr_batches=100000
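+  # With lr_epochs and lr_batches set this large, the learning rate decays very
+  # slowly and should stay close to initial_lr for the whole fine-tuning run.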
+
+  # We recommend starting from an averaged model
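+  # (here: pretrained_epoch_10_avg_2.pt from the downloaded repo, i.e. an average
+  # of 2 checkpoints up to epoch 10)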
+  finetune_ckpt=icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/pretrained_epoch_10_avg_2.pt
+  lang_dir=icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char
+  export CUDA_VISIBLE_DEVICES="0,1"
+
+  ./pruned_transducer_stateless2/finetune.py \
+    --world-size 2 \
+    --master-port 18180 \
+    --num-epochs 15 \
+    --context-size 2 \
+    --exp-dir pruned_transducer_stateless2/exp_aishell_finetune \
+    --initial-lr $initial_lr \
+    --lr-epochs $lr_epochs \
+    --lr-batches $lr_batches \
+    --lang-dir $lang_dir \
+    --do-finetune True \
+    --finetune-ckpt $finetune_ckpt \
+    --max-duration 200
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Decoding"
+
+  epoch=4
+  avg=4
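+  # Average the last $avg fine-tuned checkpoints up to epoch $epoch for decoding.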
+
+  for m in greedy_search modified_beam_search; do
+    python pruned_transducer_stateless2/decode_aishell.py \
+    --epoch $epoch \
+    --avg $avg \
+    --context-size 2 \
+    --beam-size 4 \
+    --exp-dir pruned_transducer_stateless2/exp_aishell_finetune \
+    --max-duration 400 \
+    --decoding-method $m
+  done
+fi
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
index 8a9f6ed30..20d7341db 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py
@@ -20,12 +20,7 @@ import logging
 from pathlib import Path
 
 import torch
-from lhotse import (
-    CutSet,
-    KaldifeatFbank,
-    KaldifeatFbankConfig,
-    LilcomHdf5Writer,
-)
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
 
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -74,7 +69,7 @@ def compute_fbank_wenetspeech_dev_test():
             storage_path=f"{in_out_dir}/feats_{partition}",
             num_workers=num_workers,
             batch_duration=batch_duration,
-            storage_type=LilcomHdf5Writer,
+            storage_type=LilcomChunkyWriter,
             overwrite=True,
         )
 
@@ -83,9 +78,7 @@ def compute_fbank_wenetspeech_dev_test():
 
 
 def main():
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     compute_fbank_wenetspeech_dev_test()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index a882b6113..1b257fb70 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -152,9 +152,7 @@ def main():
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
 
     log_filename = "log-compute_fbank_wenetspeech_splits"
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     log_filename = f"{log_filename}-{date_time}"
 
     logging.basicConfig(
diff --git a/egs/wenetspeech/ASR/local/display_manifest_statistics.py b/egs/wenetspeech/ASR/local/display_manifest_statistics.py
index c41445b8d..36e4ac5c3 100644
--- a/egs/wenetspeech/ASR/local/display_manifest_statistics.py
+++ b/egs/wenetspeech/ASR/local/display_manifest_statistics.py
@@ -33,6 +33,7 @@ def main():
     paths = [
         "./data/fbank/cuts_S.jsonl.gz",
         "./data/fbank/cuts_M.jsonl.gz",
+        "./data/fbank/cuts_L.jsonl.gz",
         "./data/fbank/cuts_DEV.jsonl.gz",
         "./data/fbank/cuts_TEST_NET.jsonl.gz",
         "./data/fbank/cuts_TEST_MEETING.jsonl.gz",
@@ -48,6 +49,24 @@ if __name__ == "__main__":
     main()
 
 """
+Starting display the statistics for ./data/fbank/cuts_L.jsonl.gz
+
+Cuts count: 43874235
+Total duration (hours): 30217.3
+Speech duration (hours): 30217.3 (100.0%)
+***
+Duration statistics (seconds):
+mean    2.5
+std     1.7
+min     0.2
+25%     1.4
+50%     2.0
+75%     3.0
+99%     8.4
+99.5%   9.1
+99.9%   15.4
+max     405.1
+
 Starting display the statistics for ./data/fbank/cuts_S.jsonl.gz
 Duration statistics (seconds):
 mean    2.4
diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py
index 8bc073c75..d8622842f 100755
--- a/egs/wenetspeech/ASR/local/prepare_char.py
+++ b/egs/wenetspeech/ASR/local/prepare_char.py
@@ -83,9 +83,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -138,9 +136,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
     return False
 
 
-def generate_lexicon(
-    token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
     """Generate a lexicon from a word list and token_sym_table.
     Args:
       token_sym_table:
diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
index 817969c47..93ce750f8 100755
--- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
+++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py
@@ -115,11 +115,7 @@ def preprocess_wenet_speech():
                 f"Speed perturb for {partition} with factors 0.9 and 1.1 "
                 "(Perturbing may take 8 minutes and saving may take 20 minutes)"
             )
-            cut_set = (
-                cut_set
-                + cut_set.perturb_speed(0.9)
-                + cut_set.perturb_speed(1.1)
-            )
+            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
diff --git a/egs/wenetspeech/ASR/local/text2segments.py b/egs/wenetspeech/ASR/local/text2segments.py
index df5b3c119..bdf5a3984 100644
--- a/egs/wenetspeech/ASR/local/text2segments.py
+++ b/egs/wenetspeech/ASR/local/text2segments.py
@@ -40,8 +40,8 @@ from tqdm import tqdm
 # and 'data()' is only supported in static graph mode. So if you
 # want to use this api, should call 'paddle.enable_static()' before
 # this api to enter static graph mode.
-paddle.enable_static()
-paddle.disable_signal_handler()
+# paddle.enable_static()
+# paddle.disable_signal_handler()
 jieba.enable_paddle()
 
 
diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py
index 1c463cf1c..d1d237a68 100755
--- a/egs/wenetspeech/ASR/local/text2token.py
+++ b/egs/wenetspeech/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
     parser.add_argument(
         "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
     )
-    parser.add_argument(
-        "--space", default="", type=str, help="space symbol"
-    )
+    parser.add_argument("--space", default="", type=str, help="space symbol")
     parser.add_argument(
         "--non-lang-syms",
         "-l",
@@ -66,9 +64,7 @@ def get_parser():
         type=str,
         help="list of non-linguistic symobles, e.g.,  etc.",
     )
-    parser.add_argument(
-        "text", type=str, default=False, nargs="?", help="input text"
-    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
     parser.add_argument(
         "--trans_type",
         "-t",
@@ -108,8 +104,7 @@ def token2id(
             if token_type == "lazy_pinyin":
                 text = lazy_pinyin(chars_list)
                 sub_ids = [
-                    token_table[txt] if txt in token_table else oov_id
-                    for txt in text
+                    token_table[txt] if txt in token_table else oov_id for txt in text
                 ]
                 ids.append(sub_ids)
             else:  # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(
-            sys.stdin if is_python2 else sys.stdin.buffer
-        )
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
 
     sys.stdout = codecs.getwriter("utf-8")(
         sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 755fbb2d7..f7b521794 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 nj=15
@@ -190,7 +193,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
   mkdir -p $lang_char_dir
 
   if ! which jq; then
-      echo "This script is intended to be used with jq but you have not installed jq 
+      echo "This script is intended to be used with jq but you have not installed jq
       Note: in Linux, you can install jq with the following command:
       1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
       2. chmod +x ./jq
@@ -258,3 +261,107 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
   log "Stage 18: Compile LG"
   python ./local/compile_lg.py --lang-dir $lang_char_dir
 fi
+
+# prepare RNNLM data
+if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
+  log "Stage 19: Prepare LM training data"
+
+  log "Processing char based data"
+  text_out_dir=data/lm_char
+
+  mkdir -p $text_out_dir
+
+  log "Genearating training text data"
+  
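+  # lm_data.pt stores the training text converted to token IDs (one entry per
+  # sentence); it is sorted in stage 20 and consumed by the RNN LM trainer in stage 21.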
+  if [ ! -f $text_out_dir/lm_data.pt ]; then
+    ./local/prepare_char_lm_training_data.py \
+      --lang-char data/lang_char \
+      --lm-data $lang_char_dir/text_words_segmentation \
+      --lm-archive $text_out_dir/lm_data.pt
+  fi
+
+  log "Generating DEV text data"
+  # prepare validation text data 
+  if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then
+    valid_text=${text_out_dir}/
+
+    gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \
+      | jq '.text' | sed 's/"//g' \
+      | ./local/text2token.py -t "char" > $text_out_dir/valid_text
+    
+    python3 ./local/text2segments.py \
+      --num-process $nj \
+      --input-file $text_out_dir/valid_text \
+      --output-file $text_out_dir/valid_text_words_segmentation
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $text_out_dir/valid_text_words_segmentation \
+    --lm-archive $text_out_dir/lm_data_valid.pt
+
+  # prepare TEST text data 
+  if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then
+    log "Prepare text for test set."
+    for test_set in TEST_MEETING TEST_NET; do
+        gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \
+          | jq '.text' | sed 's/"//g' \
+          | ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text
+
+        python3 ./local/text2segments.py \
+          --num-process $nj \
+          --input-file $text_out_dir/${test_set}_text \
+          --output-file $text_out_dir/${test_set}_text_words_segmentation
+    done
+    
+    cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $text_out_dir/test_text_words_segmentation \
+    --lm-archive $text_out_dir/lm_data_test.pt
+
+fi
+
+# sort RNNLM data
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  text_out_dir=data/lm_char
+
+  log "Sort lm data"
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $text_out_dir/lm_data.pt \
+    --out-lm-data $text_out_dir/sorted_lm_data.pt \
+    --out-statistics $text_out_dir/statistics.txt
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $text_out_dir/lm_data_valid.pt \
+    --out-lm-data $text_out_dir/sorted_lm_data-valid.pt \
+    --out-statistics $text_out_dir/statistics-valid.txt
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $text_out_dir/lm_data_test.pt \
+    --out-lm-data $text_out_dir/sorted_lm_data-test.pt \
+    --out-statistics $text_out_dir/statistics-test.txt
+fi
+
+export CUDA_VISIBLE_DEVICES="0,1"
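+# Adjust this to the GPUs available on your machine; stage 21 below trains with
+# --world-size 2, i.e. two GPUs.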
+
+if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
+  log "Stage 21: Train RNN LM model"
+  python ../../../icefall/rnn_lm/train.py \
+    --start-epoch 0 \
+    --world-size 2 \
+    --num-epochs 20 \
+    --use-fp16 0 \
+    --embedding-dim 2048 \
+    --hidden-dim 2048 \
+    --num-layers 2 \
+    --batch-size 400 \
+    --exp-dir rnnlm_char/exp \
+    --lm-data data/lm_char/sorted_lm_data.pt \
+    --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \
+    --vocab-size 4336 \
+    --master-port 12340
+fi
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py
new file mode 120000
index 000000000..f7321272b
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py
@@ -0,0 +1 @@
+../../../aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 10c953e3b..c9e30e737 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -46,9 +46,6 @@ from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
-set_caching_enabled(False)
-torch.set_num_threads(1)
-
 
 class _SeedWorkers:
     def __init__(self, seed: int):
@@ -212,17 +209,13 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "musan_cuts.jsonl.gz"
-        )
+        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
 
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -244,9 +237,7 @@ class WenetSpeechAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -289,9 +280,7 @@ class WenetSpeechAsrDataModule:
             # Drop feats to be on the safe side.
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -348,9 +337,7 @@ class WenetSpeechAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -358,24 +345,18 @@ class WenetSpeechAsrDataModule:
                 cut_transforms=transforms,
                 return_cuts=self.args.return_cuts,
             )
+
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
 
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        dev_iter_dataset = IterableDatasetWrapper(
-            dataset=validate,
-            sampler=valid_sampler,
-        )
         valid_dl = DataLoader(
-            dev_iter_dataset,
+            validate,
             batch_size=None,
+            sampler=valid_sampler,
             num_workers=self.args.num_workers,
             persistent_workers=False,
         )
@@ -393,19 +374,13 @@ class WenetSpeechAsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
 
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
         test_dl = DataLoader(
-            test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
@@ -414,8 +389,7 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir
-            / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
         return cuts_train
 
@@ -427,13 +401,9 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
 
     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index f0c9bebec..bdd1f27bc 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -114,11 +114,7 @@ from beam_search import (
 from train import get_params, get_transducer_model
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import (
-    average_checkpoints,
-    find_checkpoints,
-    load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -252,8 +248,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -328,9 +323,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -389,10 +382,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -515,9 +505,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -528,18 +516,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -549,10 +533,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -663,96 +644,25 @@ def main():
             )
             decoding_graph.scores *= params.ngram_lm_scale
         else:
-            decoding_graph = k2.trivial_graph(
-                params.vocab_size - 1, device=device
-            )
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
     else:
         decoding_graph = None
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
 
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
 
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
 
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
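
The hunk above removes the webdataset export/import round trip from decode.py: evaluation now builds dataloaders directly from the data module's lazily loaded CutSets. A compact sketch of the resulting flow is shown below; `run_decode` is a hypothetical placeholder standing in for the per-dataset decode_dataset/save_results calls in the script.

def evaluate_all(wenetspeech, run_decode):
    # Dataloaders are created straight from the lazily loaded manifests; no
    # export_to_webdataset / CutSet.from_webdataset step is needed anymore.
    dev_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
    test_net_dl = wenetspeech.test_dataloaders(wenetspeech.test_net_cuts())
    test_meeting_dl = wenetspeech.test_dataloaders(wenetspeech.test_meeting_cuts())

    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
    test_dls = [dev_dl, test_net_dl, test_meeting_dl]
    for name, dl in zip(test_sets, test_dls):
        run_decode(name, dl)
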
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py
new file mode 100755
index 000000000..2e644ec2f
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py
@@ -0,0 +1,547 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless2/decode_aishell.py \
+    --epoch 84 \
+    --avg 25 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless2/decode_aishell.py \
+    --epoch 84 \
+    --avg 25 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless2/decode_aishell.py \
+    --epoch 84 \
+    --avg 25 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless2/decode_aishell.py \
+    --epoch 84 \
+    --avg 25 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 4 \
+    --max-contexts 4 \
+    --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from aishell import AishellAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from finetune import get_params, get_transducer_model
+
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    write_error_stats,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=1,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    token_table: k2.SymbolTable,
+    batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      token_table:
+        It maps token ID to a string.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+    else:
+        hyp_tokens = []
+        batch_size = encoder_out.size(0)
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_tokens.append(hyp)
+
+    hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    token_table: k2.SymbolTable,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      token_table:
+        It maps a token ID to a string.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            token_table=token_table,
+            decoding_graph=decoding_graph,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        # we compute CER for aishell dataset.
+        results_char = []
+        for res in results:
+            results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results_char, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    AishellAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = 0
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+    elif params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 1:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+
+    model.to(device)
+    model.eval()
+
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    aishell = AishellAsrDataModule(args)
+    test_cuts = aishell.test_cuts()
+    dev_cuts = aishell.valid_cuts()
+    test_dl = aishell.test_dataloaders(test_cuts)
+    dev_dl = aishell.test_dataloaders(dev_cuts)
+
+    test_sets = ["test", "dev"]
+    test_dls = [test_dl, dev_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            token_table=lexicon.token_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
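
A note on the CER computation in save_results() of decode_aishell.py above: word-level references and hypotheses are joined and re-split into characters before being passed to write_error_stats, so the reported "WER" is effectively a character error rate. A self-contained sketch of that transformation follows; the helper name `to_char_level` is illustrative only.

def to_char_level(results):
    """results: a list of (cut_id, ref_words, hyp_words) tuples."""
    return [
        (cut_id, list("".join(ref_words)), list("".join(hyp_words)))
        for cut_id, ref_words, hyp_words in results
    ]


# Example:
#   to_char_level([("utt1", ["甚至", "出现"], ["甚至", "出线"])])
#   -> [("utt1", ["甚", "至", "出", "现"], ["甚", "至", "出", "线"])]
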
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index 933642a0f..8c4fbdd47 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -205,8 +205,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     return parser
@@ -468,13 +467,9 @@ def export_joiner_model_onnx(
 
         - projected_decoder_out: a tensor of shape (N, joiner_dim)
     """
-    encoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_encoder_proj.onnx"
-    )
+    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
 
-    decoder_proj_filename = str(joiner_filename).replace(
-        ".onnx", "_decoder_proj.onnx"
-    )
+    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
 
     encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
     decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
@@ -645,9 +640,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
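
The export.py hunk above only unwraps the filename derivation for the joiner's projection models. For clarity, the convention it implements is a simple suffix substitution on the joiner's ONNX filename; the helper below is an illustration, not part of the recipe.

def derive_proj_filenames(joiner_filename: str):
    # "joiner.onnx" -> ("joiner_encoder_proj.onnx", "joiner_decoder_proj.onnx")
    encoder_proj = joiner_filename.replace(".onnx", "_encoder_proj.onnx")
    decoder_proj = joiner_filename.replace(".onnx", "_decoder_proj.onnx")
    return encoder_proj, decoder_proj
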
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py
new file mode 100755
index 000000000..e703100a9
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py
@@ -0,0 +1,1050 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless2/finetune.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --full-libri 1 \
+  --do-finetune 1 \
+  --max-duration 100
+
+"""
+
+
+import argparse
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from aishell import AishellAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument("--do-finetune", type=str2bool, default=False)
+
+    parser.add_argument(
+        "--init-modules",
+        type=str,
+        default=None,
+        help="""
+        Modules to be initialized. It matches all parameters starting with
+        a specific key. The keys are given with Comma seperated. If None,
+        all modules will be initialised. For example, if you only want to
+        initialise all parameters staring with "encoder", use "encoder";
+        if you want to initialise parameters starting with encoder or decoder,
+        use "encoder,joiner".
+        """,
+    )
+
+    parser.add_argument(
+        "--finetune-ckpt",
+        type=str,
+        default=None,
+        help="Fine-tuning from which checkpoint (a path to a .pt file)",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from from this epoch.
+        If it is positive, it will load checkpoint from
+        pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless2/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.0001,
+        help="The initial learning rate.  This value should not need to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=100000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. During fine-tuning, we set this very large so that the
+        learning rate slowly decays with number of batches. You may tune
+        its value by yourself.
+        """,
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=100,
+        help="""Number of epochs that affects how rapidly the learning rate
+        decreases. During fine-tuning, we set this very large so that the
+        learning rate slowly decays with number of batches. You may tune
+        its value by yourself.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=1,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--valid-interval",
+        type=int,
+        default=3000,
+        help="""When training_subset is L, set the valid_interval to 3000.
+        When training_subset is M, set the valid_interval to 1000.
+        When training_subset is S, set the valid_interval to 400.
+        """,
+    )
+
+    parser.add_argument(
+        "--model-warm-step",
+        type=int,
+        default=3000,
+        help="""When training_subset is L, set the model_warm_step to 3000.
+        When training_subset is M, set the model_warm_step to 500.
+        When training_subset is S, set the model_warm_step to 100.
+        """,
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+    Explanation of options saved in `params`:
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+        - best_train_epoch: It is the epoch that has the best training loss.
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+        - subsampling_factor:  The subsampling factor for the model.
+        - encoder_dim: Hidden dim for multi-head attention model.
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "encoder_dim": 512,
+            "nhead": 8,
+            "dim_feedforward": 2048,
+            "num_encoder_layers": 12,
+            # parameters for decoder
+            "decoder_dim": 512,
+            # parameters for joiner
+            "joiner_dim": 512,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
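
As the get_params() docstring above notes, command-line options are merged into `params`, after which everything is reachable through attribute access. A tiny usage sketch, assuming icefall is importable; `args` stands for the parsed command-line namespace and the values below are hypothetical.

from icefall.utils import AttributeDict

params = AttributeDict({"feature_dim": 80, "subsampling_factor": 4})
# params.update(vars(args))  # merge parsed command-line options into params
assert params.feature_dim == 80  # attribute-style access to dict entries
params.vocab_size = 500  # hypothetical value; new entries are added the same way
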
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`.
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 0:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def load_model_params(
+    ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+    """Load model params from checkpoint
+
+    Args:
+        ckpt (str): Path to the checkpoint
+        model (nn.Module): model to be loaded
+
+    """
+    logging.info(f"Loading checkpoint from {ckpt}")
+    checkpoint = torch.load(ckpt, map_location="cpu")
+
+    # if module list is empty, load the whole model from ckpt
+    if not init_modules:
+        if next(iter(checkpoint["model"])).startswith("module."):
+            logging.info("Loading checkpoint saved by DDP")
+
+            dst_state_dict = model.state_dict()
+            src_state_dict = checkpoint["model"]
+            for key in dst_state_dict.keys():
+                src_key = "{}.{}".format("module", key)
+                dst_state_dict[key] = src_state_dict.pop(src_key)
+            assert len(src_state_dict) == 0
+            model.load_state_dict(dst_state_dict, strict=strict)
+        else:
+            model.load_state_dict(checkpoint["model"], strict=strict)
+    else:
+        src_state_dict = checkpoint["model"]
+        dst_state_dict = model.state_dict()
+        for module in init_modules:
+            logging.info(f"Loading parameters starting with prefix {module}")
+            src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
+            dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+            assert set(src_keys) == set(dst_keys)  # two sets should match exactly
+            for key in src_keys:
+                dst_state_dict[key] = src_state_dict.pop(key)
+
+        model.load_state_dict(dst_state_dict, strict=strict)
+
+    return None
+
+
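
The prefix matching in load_model_params() above is what the --init-modules option drives: only parameters whose state-dict keys start with one of the given prefixes are copied from the checkpoint. A self-contained sketch with a toy state dict; `select_by_prefix` is illustrative only.

def select_by_prefix(src_state_dict: dict, init_modules: list) -> dict:
    selected = {}
    for module in init_modules:
        for key in src_state_dict:
            if key.startswith(module):
                selected[key] = src_state_dict[key]
    return selected


# Example: initialise only the encoder and joiner from a checkpoint.
#   select_by_prefix(
#       {"encoder.l0.weight": 1, "decoder.embed.weight": 2, "joiner.out.bias": 3},
#       init_modules=["encoder", "joiner"],
#   )
#   -> {"encoder.l0.weight": 1, "joiner.out.bias": 3}
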
+def save_checkpoint(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Transducer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+     warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+
+    y = graph_compiler.texts_to_ids(texts)
+    if isinstance(y, list):
+        y = k2.RaggedTensor(y).to(device)
+    else:
+        y = y.to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+        )
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
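
The warm-up logic inside compute_loss() above keeps the pruned loss switched off until warm-up completes and then down-weights it for roughly one more model_warm_step. Written as a standalone function, the same schedule reads as follows; this is an illustration of the existing expression, not new recipe code.

def pruned_loss_scale(warmup: float) -> float:
    # 0.0 while warming up (warmup < 1.0), 0.1 for 1.0 < warmup < 2.0,
    # and 1.0 otherwise; warmup itself is batch_idx_train / model_warm_step.
    return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)
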
+def compute_validation_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(
+                batch, params=params, graph_compiler=graph_compiler
+            )
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
+    params.blank_id = lexicon.token_table[""]
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # load model parameters for model fine-tuning
+    if params.do_finetune:
+        modules = params.init_modules.split(",") if params.init_modules else None
+        checkpoints = load_model_params(
+            ckpt=params.finetune_ckpt, model=model, init_modules=modules
+        )
+    else:
+        assert params.start_epoch > 0, params.start_epoch
+        checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+    model.device = device
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    aishell = AishellAsrDataModule(args)
+    train_dl = aishell.train_dataloaders(aishell.train_cuts())
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
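+        # Re-seed the RNGs and the sampler each epoch so that data shuffling
+        # differs across epochs while remaining reproducible.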
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            graph_compiler=graph_compiler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+    """Display the batch statistics and save the batch to disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      graph_compiler:
+        The graph compiler; its `texts_to_ids()` is used below to map the
+        transcripts to token ids.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = graph_compiler.texts_to_ids(supervisions["text"])
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    graph_compiler: CharCtcTrainingGraphCompiler,
+    params: AttributeDict,
+):
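+    """Run a forward/backward pass (and an optimizer step) on the batches
+    the sampler considers most demanding, so that out-of-memory errors
+    surface before the first real epoch starts.
+    """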
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+            # (i.e. are not remembered by the decaying-average in adam), because
+            # we want to avoid these params being subject to shrinkage in adam.
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=0.0 if params.start_epoch == 1 else 1.0,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            raise
+
+
+def main():
+    parser = get_parser()
+    AishellAsrDataModule.add_arguments(
+        parser
+    )  # you may replace this with your own dataset
+    add_finetune_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index e5cc47bfe..f90dd2b43 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -145,10 +145,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -331,9 +330,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
index c396c50ef..a46ff5a07 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py
@@ -219,9 +219,7 @@ def test_joiner(
         )
 
         # Now test encoder_proj
-        joiner_encoder_proj_inputs = {
-            encoder_proj_input_name: encoder_out.numpy()
-        }
+        joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()}
         joiner_encoder_proj_out = joiner_encoder_proj_session.run(
             [encoder_proj_output_name], joiner_encoder_proj_inputs
         )[0]
@@ -230,16 +228,10 @@ def test_joiner(
         torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out)
         assert torch.allclose(
             joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5
-        ), (
-            (joiner_encoder_proj_out - torch_joiner_encoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max())
 
         # Now test decoder_proj
-        joiner_decoder_proj_inputs = {
-            decoder_proj_input_name: decoder_out.numpy()
-        }
+        joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()}
         joiner_decoder_proj_out = joiner_decoder_proj_session.run(
             [decoder_proj_output_name], joiner_decoder_proj_inputs
         )[0]
@@ -248,11 +240,7 @@ def test_joiner(
         torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out)
         assert torch.allclose(
             joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5
-        ), (
-            (joiner_decoder_proj_out - torch_joiner_decoder_proj_out)
-            .abs()
-            .max()
-        )
+        ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max())
 
 
 @torch.no_grad()
@@ -304,9 +292,7 @@ def main():
 
 if __name__ == "__main__":
     torch.manual_seed(20220727)
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
index 3770fbbb4..9e34b4427 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py
@@ -149,10 +149,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -200,11 +199,7 @@ def greedy_search(
 
     projected_encoder_out = joiner_encoder_proj.run(
         [joiner_encoder_proj.get_outputs()[0].name],
-        {
-            joiner_encoder_proj.get_inputs()[
-                0
-            ].name: packed_encoder_out.data.numpy()
-        },
+        {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
     )[0]
 
     blank_id = 0  # hard-code to 0
@@ -389,9 +384,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index 9a549efd9..bc499f3dd 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -158,8 +158,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -189,10 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -253,9 +251,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +276,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +328,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index d3cc7c9c9..48b347b64 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -115,9 +115,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
@@ -219,8 +217,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -243,8 +240,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -590,22 +586,15 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -762,9 +751,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -864,7 +851,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -874,15 +861,41 @@ def run(rank, world_size, args):
     valid_cuts = wenetspeech.valid_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # Keep only utterances with duration between 1 second and 10 seconds
         #
-        # Caution: There is a reason to select 15.0 here. Please see
+        # Caution: There is a reason to select 10.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 15.0
+        if c.duration < 1.0 or c.duration > 10.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
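+        # For example, a 10-second cut has about 1000 feature frames
+        # (10 ms frame shift), giving T = ((1000 - 1) // 2 - 1) // 2 = 249.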
+        tokens = c.supervisions[0].text.replace(" ", "")
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
index dd27c17f0..23a877b2f 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -210,10 +210,7 @@ class Conformer(EncoderInterface):
           (num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
           NOTE: the returned tensors are on the given device.
         """
-        if (
-            len(self._init_state) == 2
-            and self._init_state[0].size(1) == left_context
-        ):
+        if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context:
             # Note: It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -433,9 +430,7 @@ class ConformerEncoderLayer(nn.Module):
 
         self.d_model = d_model
 
-        self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0
-        )
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
 
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
@@ -453,9 +448,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )
 
-        self.conv_module = ConvolutionModule(
-            d_model, cnn_module_kernel, causal=causal
-        )
+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal)
 
         self.norm_final = BasicNorm(d_model)
 
@@ -520,9 +513,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, _ = self.conv_module(
-            src, src_key_padding_mask=src_key_padding_mask
-        )
+        conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
         src = src + self.dropout(conv)
 
         # feed forward module
@@ -766,9 +757,7 @@ class RelPositionalEncoding(torch.nn.Module):
         max_len: Maximum input length.
     """
 
-    def __init__(
-        self, d_model: int, dropout_rate: float, max_len: int = 5000
-    ) -> None:
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -784,9 +773,7 @@ class RelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(1) >= x_size_1 * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # Suppose `i` means to the position of query vector and `j` means the
@@ -979,20 +966,32 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
 
         time2 = time1 + left_context
-        assert (
-            n == left_context + 2 * time1 - 1
-        ), f"{n} == {left_context} + 2 * {time1} - 1"
+        if not torch.jit.is_tracing():
+            assert (
+                n == left_context + 2 * time1 - 1
+            ), f"{n} == {left_context} + 2 * {time1} - 1"
 
-        # Note: TorchScript requires explicit arg for stride()
-        batch_stride = x.stride(0)
-        head_stride = x.stride(1)
-        time1_stride = x.stride(2)
-        n_stride = x.stride(3)
-        return x.as_strided(
-            (batch_size, num_heads, time1, time2),
-            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
-            storage_offset=n_stride * (time1 - 1),
-        )
+        if torch.jit.is_tracing():
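+            # The as_strided() view in the else branch is not supported when
+            # exporting via torch.jit.trace() (e.g. to ONNX), so build the
+            # shifted matrix with an explicit gather over precomputed indexes.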
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(time2)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+
+            x = x.reshape(-1, n)
+            x = torch.gather(x, dim=1, index=indexes)
+            x = x.reshape(batch_size, num_heads, time1, time2)
+            return x
+        else:
+            # Note: TorchScript requires explicit arg for stride()
+            batch_stride = x.stride(0)
+            head_stride = x.stride(1)
+            time1_stride = x.stride(2)
+            n_stride = x.stride(3)
+            return x.as_strided(
+                (batch_size, num_heads, time1, time2),
+                (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+                storage_offset=n_stride * (time1 - 1),
+            )
 
     def multi_head_attention_forward(
         self,
@@ -1073,9 +1072,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
 
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -1144,31 +1143,22 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
                 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
-                    raise RuntimeError(
-                        "The size of the 2D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
                     query.size(0),
                     key.size(0),
                 ]:
-                    raise RuntimeError(
-                        "The size of the 3D attn_mask is not correct."
-                    )
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
             else:
                 raise RuntimeError(
-                    "attn_mask's dimension {} is not supported".format(
-                        attn_mask.dim()
-                    )
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
                 )
             # attn_mask's dim is 3 now.
 
         # convert ByteTensor key_padding_mask to bool
-        if (
-            key_padding_mask is not None
-            and key_padding_mask.dtype == torch.uint8
-        ):
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
             warnings.warn(
                 "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
             )
@@ -1208,23 +1198,15 @@ class RelPositionMultiheadAttention(nn.Module):
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
         matrix_bd = self.rel_shift(matrix_bd, left_context)
 
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
 
-        attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
-        )
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
@@ -1265,21 +1247,17 @@ class RelPositionMultiheadAttention(nn.Module):
         ):
             if attn_mask.size(0) != 1:
                 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
-                combined_mask = attn_mask | key_padding_mask.unsqueeze(
-                    1
-                ).unsqueeze(2)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
             else:
                 # attn_mask.shape == (1, tgt_len, src_len)
-                combined_mask = attn_mask.unsqueeze(
-                    0
-                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
 
             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
-            attn_output_weights = attn_output_weights.masked_fill(
-                combined_mask, 0.0
-            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
             )
@@ -1291,13 +1269,9 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
         attn_output = (
-            attn_output.transpose(0, 1)
-            .contiguous()
-            .view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
 
         if need_weights:
             # average attention weights over heads
@@ -1430,16 +1404,12 @@ class ConvolutionModule(nn.Module):
                 # manualy padding self.lorder zeros to the left
                 x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
             else:
-                assert (
-                    not self.training
-                ), "Cache should be None in training time"
+                assert not self.training, "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
                 if right_context > 0:
                     cache = x.permute(2, 0, 1)[
-                        -(self.lorder + right_context) : (  # noqa
-                            -right_context
-                        ),
+                        -(self.lorder + right_context) : (-right_context),  # noqa
                         ...,
                     ]
                 else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 344e31283..46ba6b005 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -2,6 +2,7 @@
 #
 # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
 # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -91,6 +92,22 @@ When training with the L subset, the streaming usage:
         --causal-convolution 1 \
         --decode-chunk-size 16 \
         --left-context 64
+
+(4) modified beam search with RNNLM shallow fusion
+./pruned_transducer_stateless5/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_lm_shallow_fusion \
+    --use-shallow-fusion 1 \
+    --beam-size 4 \
+    --lm-type rnn \
+    --lm-scale 0.3 \
+    --lm-exp-dir /path/to/LM \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
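+
+(5) modified beam search with LODR (an illustrative invocation; the n-gram
+    order, paths and scales are placeholders to adapt to your setup)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_LODR \
+    --beam-size 4 \
+    --use-shallow-fusion 1 \
+    --lm-type rnn \
+    --lm-scale 0.3 \
+    --lm-exp-dir /path/to/LM \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16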
 """
 
 
@@ -111,9 +128,12 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_lm_shallow_fusion,
+    modified_beam_search_LODR,
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import LmScorer, NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -224,6 +244,16 @@ def get_parser():
         Used only when --decoding-method is fast_beam_search""",
     )
 
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when the decoding method is modified_beam_search_LODR
+        or an n-gram rescoring method.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
@@ -244,8 +274,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
@@ -278,6 +307,50 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )
 
+    parser.add_argument(
+        "--use-shallow-fusion",
+        type=str2bool,
+        default=False,
+        help="""Use neural network LM for shallow fusion.
+        If you want to use LODR, you also need to set this to True.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-type",
+        type=str,
+        default="rnn",
+        help="Type of neural network LM",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.3,
+        help="""The scale of the neural network LM.
+        Used only when `--use-shallow-fusion` is set to True.
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=3,
+        help="""The order of the token-level n-gram LM used for rescoring.
+            Used only when the decoding method is
+            modified_beam_search_ngram_rescoring or LODR.
+            """,
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="""ID of the backoff symbol.
+                Used only when the decoding method is
+                modified_beam_search_ngram_rescoring""",
+    )
     add_model_arguments(parser)
 
     return parser
@@ -289,6 +362,9 @@ def decode_one_batch(
     lexicon: Lexicon,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -342,9 +418,7 @@ def decode_one_batch(
             simulate_streaming=True,
         )
     else:
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
 
@@ -360,10 +434,7 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -380,6 +451,28 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+        hyp_tokens = modified_beam_search_lm_shallow_fusion(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            LM=LM,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "modified_beam_search_LODR":
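+        # LODR: shallow fusion with the neural LM while subtracting a scaled
+        # low-order (e.g. bigram) n-gram LM trained on the source domain,
+        # approximating the removal of the transducer's internal LM.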
+        hyp_tokens = modified_beam_search_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            LODR_lm=ngram_lm,
+            LODR_lm_scale=ngram_lm_scale,
+            LM=LM,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     else:
         batch_size = encoder_out.size(0)
 
@@ -425,6 +518,9 @@ def decode_dataset(
     model: nn.Module,
     lexicon: Lexicon,
     decoding_graph: Optional[k2.Fsa] = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
@@ -438,6 +534,8 @@ def decode_dataset(
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search.
+      LM:
+        A neural network LM, used during shallow fusion
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -455,7 +553,7 @@ def decode_dataset(
     if params.decoding_method == "greedy_search":
         log_interval = 100
     else:
-        log_interval = 2
+        log_interval = 20
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -469,6 +567,9 @@ def decode_dataset(
             lexicon=lexicon,
             decoding_graph=decoding_graph,
             batch=batch,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+            LM=LM,
         )
 
         for name, hyps in hyps_dict.items():
@@ -484,9 +585,7 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -497,18 +596,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -518,10 +613,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -539,6 +631,7 @@ def save_results(
 def main():
     parser = get_parser()
     WenetSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
@@ -550,6 +643,8 @@ def main():
         "beam_search",
         "fast_beam_search",
         "modified_beam_search",
+        "modified_beam_search_lm_shallow_fusion",
+        "modified_beam_search_LODR",
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
@@ -564,6 +659,22 @@ def main():
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
 
+    if "ngram" in params.decoding_method:
+        params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    if params.use_shallow_fusion:
+        if params.lm_type == "rnn":
+            params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
+        elif params.lm_type == "transformer":
+            params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
+
+        if "LODR" in params.decoding_method:
+            params.suffix += (
+                f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+            )
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
 
@@ -573,6 +684,7 @@ def main():
 
     logging.info(f"Device: {device}")
 
     lexicon = Lexicon(params.lang_dir)
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
@@ -589,9 +701,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -618,9 +730,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -667,6 +779,37 @@ def main():
     model.to(device)
     model.eval()
     model.device = device
+    # only load N-gram LM when needed
+    if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
+        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+        logging.info(f"lm filename: {lm_filename}")
+        ngram_lm = NgramLm(
+            str(params.lang_dir / lm_filename),
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {ngram_lm.lm.num_states}")
+        ngram_lm_scale = params.ngram_lm_scale
+    else:
+        ngram_lm = None
+        ngram_lm_scale = None
+
+    # only load the neural network LM if doing shallow fusion
+    if params.use_shallow_fusion:
+        LM = LmScorer(
+            lm_type=params.lm_type,
+            params=params,
+            device=device,
+            lm_scale=params.lm_scale,
+        )
+        LM.to(device)
+        LM.eval()
+
+        num_param = sum([p.numel() for p in LM.parameters()])
+        logging.info(f"Number of LM parameters: {num_param}")
+    else:
+        LM = None
 
     if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
@@ -676,87 +819,18 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)
 
-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
 
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
 
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(
-            glob.glob(os.path.join(test_meeting, "shared-*.tar"))
-        )
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
 
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
@@ -768,6 +842,9 @@ def main():
             model=model,
             lexicon=lexicon,
             decoding_graph=decoding_graph,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+            LM=LM,
         )
         save_results(
             params=params,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
index 386248554..e522943c0 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -75,9 +75,7 @@ class DecodeStream(object):
         # encoder.streaming_forward
         self.done_frames: int = 0
 
-        self.pad_length = (
-            params.right_context + 2
-        ) * params.subsampling_factor + 3
+        self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3
 
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
@@ -91,13 +89,11 @@ class DecodeStream(object):
             )
         elif params.decoding_method == "fast_beam_search":
             # The rnnt_decoding_stream for fast_beam_search.
-            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
-                k2.RnntDecodingStream(decoding_graph)
+            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
+                decoding_graph
             )
         else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
+            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     @property
     def done(self) -> bool:
@@ -126,13 +122,10 @@ class DecodeStream(object):
         """Consume chunk_size frames of features"""
         chunk_length = chunk_size + self.pad_length
 
-        ret_length = min(
-            self.num_frames - self.num_processed_frames, chunk_length
-        )
+        ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
 
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_length
+            self.num_processed_frames : self.num_processed_frames + ret_length  # noqa
         ]
 
         self.num_processed_frames += chunk_size
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
old mode 100644
new mode 100755
index d0a7fd69f..cb541070e
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -74,6 +74,7 @@ import logging
 from pathlib import Path
 
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -131,8 +132,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
     add_model_arguments(parser)
 
@@ -185,6 +185,7 @@ def main():
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
         # torch scriptabe.
+        convert_scaled_to_non_scaled(model, inplace=True)
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
@@ -201,9 +202,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..d13a1e063
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/lstmp.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 1b064c874..1cac20435 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -157,8 +157,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -189,10 +188,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -253,9 +251,7 @@ def main():
     features = fbank(waves)
     feature_lengths = [f.size(0) for f in features]
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     feature_lengths = torch.tensor(feature_lengths, device=device)
 
@@ -280,10 +276,7 @@ def main():
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-    elif (
-        params.decoding_method == "greedy_search"
-        and params.max_sym_per_frame == 1
-    ):
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
@@ -335,9 +328,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..e58473a04
--- /dev/null
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
index 651aff6c9..810d94135 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -173,14 +173,10 @@ def modified_beam_search(
         log_probs_shape = k2.ragged.create_ragged_shape2(
             row_splits=row_splits, cached_tot_size=log_probs.numel()
         )
-        ragged_log_probs = k2.RaggedTensor(
-            shape=log_probs_shape, value=log_probs
-        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
 
         for i in range(batch_size):
-            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
-                num_active_paths
-            )
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
 
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index ff96c6487..3a4dc3cb8 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -201,8 +201,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -311,9 +310,7 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
-        greedy_search(
-            model=model, encoder_out=encoder_out, streams=decode_streams
-        )
+        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
@@ -333,9 +330,7 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
 
@@ -389,9 +384,7 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_state(params.left_context, device=device)
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -461,9 +454,7 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
 
     return {key: decode_results}
 
@@ -475,9 +466,7 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         # sort results so we can easily compare the difference between two
         # recognition results
         results = sorted(results)
@@ -486,9 +475,7 @@ def save_results(
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -498,10 +485,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -565,9 +549,9 @@ def main():
 
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -594,9 +578,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
index 2052e9da7..34a72be8f 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py
@@ -98,9 +98,7 @@ from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
@@ -260,8 +258,7 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate.  This value should not need "
-        "to be changed.",
+        help="The initial learning rate.  This value should not need to be changed.",
     )
 
     parser.add_argument(
@@ -284,8 +281,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
     parser.add_argument(
@@ -308,8 +304,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )
 
     parser.add_argument(
@@ -665,11 +660,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -701,23 +692,16 @@ def compute_loss(
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
         pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
         )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -841,9 +825,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -901,9 +883,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -1016,7 +996,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1026,15 +1006,41 @@ def run(rank, world_size, args):
     valid_cuts = wenetspeech.valid_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # Keep only utterances with duration between 1 second and 10 seconds
         #
-        # Caution: There is a reason to select 15.0 here. Please see
+        # Caution: There is a reason to select 10.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 15.0
+        if c.duration < 1.0 or c.duration > 10.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
+        tokens = c.supervisions[0].text.replace(" ", "")
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
@@ -1184,9 +1190,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-            display_and_save_batch(
-                batch, params=params, graph_compiler=graph_compiler
-            )
+            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
             raise
 
 
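The new `remove_short_and_long_utt` filter above enforces the pruned RNN-T requirement that the number of subsampled frames T is at least the number of tokens S, using the Conformer subsampling expression from the hunk. Below is a minimal, self-contained sketch of that check on plain numbers (not part of the patch); the 10 ms frame shift and the example transcripts are assumptions for illustration only.

```python
# Standalone sketch of the T >= S filter added in train.py above.
# Assumption: a 10 ms frame shift, so duration = num_frames * 0.01 s.


def frames_after_subsampling(num_frames: int) -> int:
    # The conv module in conformer.py subsamples by roughly 4x:
    return ((num_frames - 1) // 2 - 1) // 2


def keep_utterance(num_frames: int, text: str) -> bool:
    duration = num_frames * 0.01
    if duration < 1.0 or duration > 10.0:
        return False
    # For Chinese transcripts, every non-space character is one token.
    num_tokens = len(text.replace(" ", ""))
    return frames_after_subsampling(num_frames) >= num_tokens


print(keep_utterance(500, "今天 天气 很好"))  # 5 s, 6 tokens -> True
print(keep_utterance(30, "今天 天气 很好"))   # 0.3 s, too short -> False
```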
diff --git a/egs/xbmu_amdo31/ASR/README.md b/egs/xbmu_amdo31/ASR/README.md
new file mode 100644
index 000000000..0a441d070
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/README.md
@@ -0,0 +1,16 @@
+# Introduction
+
+XBMU-AMDO31 is an open-source Amdo Tibetan speech corpus published by Northwest
+Minzu University. It is publicly available at
+https://huggingface.co/datasets/syzym/xbmu_amdo31
+
+The corpus contains 31 hours of speech data together with the resources needed
+to build speech recognition systems, including transcribed texts and a Tibetan
+pronunciation lexicon. (The lexicon is a Lhasa-dialect Tibetan lexicon, reused
+for the Amdo dialect because of the uniformity of the Tibetan language.)
+The dataset can be used to train models for Amdo Tibetan Automatic Speech
+Recognition (ASR).
+
+This recipe includes several ASR models trained with XBMU-AMDO31.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/RESULTS.md b/egs/xbmu_amdo31/ASR/RESULTS.md
new file mode 100644
index 000000000..1bd9b2e2b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/RESULTS.md
@@ -0,0 +1,92 @@
+## Results
+
+### XBMU-AMDO31 BPE training result (Stateless Transducer)
+
+#### Pruned transducer stateless 5
+
+[./pruned_transducer_stateless5](./pruned_transducer_stateless5)
+
+It uses pruned RNN-T.
+
+A pre-trained model and decoding logs can be found at 
+
+You can use  to deploy it.
+
+Number of model parameters: 87801200, i.e., 87.8 M
+
+|                        | test | dev  | comment                               |
+|------------------------|------|------|---------------------------------------|
+| greedy search          | 11.06| 11.73| --epoch 28 --avg 23 --max-duration 600|
+| beam search            | 10.64| 11.42| --epoch 28 --avg 23 --max-duration 600|
+| modified beam search   | 10.57| 11.24| --epoch 28 --avg 23 --max-duration 600|
+
+
+Training command is:
+
+```bash
+cd egs/xbmu_amdo31/ASR
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0"
+
+./pruned_transducer_stateless5/train.py
+```
+
+**Caution**: It uses `--context-size=1`.
+
+
+The decoding command is:
+```bash
+for method in greedy_search beam_search modified_beam_search;
+do
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 23 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method $method
+done
+```
+
+#### pruned_transducer_stateless7 (zipformer)
+
+See  for more details.
+
+[pruned_transducer_stateless7](./pruned_transducer_stateless7)
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+You can use  to deploy it.
+
+Number of model parameters: 70369391, i.e., 70.37 M
+
+|                      | test | dev  | comment                                |
+|----------------------|------|------|----------------------------------------|
+| greedy search        | 10.06| 10.59| --epoch 23 --avg 11 --max-duration 600 |
+| beam search          | 9.77 | 10.11| --epoch 23 --avg 11 --max-duration 600 |
+| modified beam search | 9.7  | 10.12| --epoch 23 --avg 11 --max-duration 600 |
+
+The training commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+
+./pruned_transducer_stateless7/train.py
+```
+
+The decoding commands are:
+```bash
+for m in greedy_search beam_search modified_beam_search; do
+  for epoch in 23; do
+    for avg in 11; do
+      ./pruned_transducer_stateless7/decode.py \
+          --epoch $epoch \
+          --avg $avg \
+          --exp-dir ./pruned_transducer_stateless7/exp \
+          --max-duration 600 \
+          --decoding-method $m
+    done
+  done
+done
+```
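The parameter counts quoted in this RESULTS.md (87.8 M and 70.37 M) are totals over all model parameters. As a reference, such a figure is typically obtained as in the hedged sketch below; the small `nn.Sequential` is only a placeholder standing in for the transducer built by the recipe's `train.py`.

```python
import torch.nn as nn

# Placeholder network; in the recipe the count is taken over the transducer
# returned by get_transducer_model() in train.py.
model = nn.Sequential(nn.Linear(80, 512), nn.ReLU(), nn.Linear(512, 500))

num_param = sum(p.numel() for p in model.parameters())
print(f"Number of model parameters: {num_param}")
```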
diff --git a/egs/xbmu_amdo31/ASR/local/compile_hlg.py b/egs/xbmu_amdo31/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compile_lg.py b/egs/xbmu_amdo31/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py
new file mode 100755
index 000000000..a593e7be3
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the XBMU-AMDO31 dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to the bpe.model. If not None, we will remove short and
+        long utterances before extracting features""",
+    )
+    return parser.parse_args()
+
+
+def compute_fbank_xbmu_amdo31(bpe_model: Optional[str] = None):
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    num_jobs = min(15, os.cpu_count())
+    num_mel_bins = 80
+
+    if bpe_model:
+        logging.info(f"Loading {bpe_model}")
+        sp = spm.SentencePieceProcessor()
+        sp.load(bpe_model)
+
+    dataset_parts = (
+        "train",
+        "dev",
+        "test",
+    )
+    prefix = "xbmu_amdo31"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        for partition, m in manifests.items():
+            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+            if (output_dir / cuts_filename).is_file():
+                logging.info(f"{partition} already exists - skipping.")
+                continue
+            logging.info(f"Processing {partition}")
+            cut_set = CutSet.from_manifests(
+                recordings=m["recordings"],
+                supervisions=m["supervisions"],
+            )
+            if bpe_model:
+                cut_set = filter_cuts(cut_set, sp)
+
+            if "train" in partition:
+                cut_set = (
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                )
+            cut_set = cut_set.compute_and_store_features(
+                extractor=extractor,
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+                # when an executor is specified, make more partitions
+                num_jobs=num_jobs if ex is None else 80,
+                executor=ex,
+                storage_type=LilcomChunkyWriter,
+            )
+            cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    logging.info(vars(args))
+    compute_fbank_xbmu_amdo31(bpe_model=args.bpe_model)
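The script above writes the cut manifests to `data/fbank/xbmu_amdo31_cuts_{train,dev,test}.jsonl.gz`, which the recipe's data module later loads lazily. The hedged sketch below (not part of the patch) shows one way to inspect the generated manifest, assuming the script has been run from `egs/xbmu_amdo31/ASR` so the output file exists.

```python
# Hedged sketch: inspect the cuts manifest written by compute_fbank_xbmu_amdo31.py.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/xbmu_amdo31_cuts_train.jsonl.gz")
for i, cut in enumerate(cuts):
    print(cut.id, f"{cut.duration:.2f}s", cut.supervisions[0].text)
    if i == 4:  # only look at the first few cuts
        break
```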
diff --git a/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/filter_cuts.py b/egs/xbmu_amdo31/ASR/local/filter_cuts.py
new file mode 120000
index 000000000..27aca1729
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/filter_cuts.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/filter_cuts.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang.py b/egs/xbmu_amdo31/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py
new file mode 120000
index 000000000..abc00d421
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lm_training_data.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py
new file mode 120000
index 000000000..1d6ccbe33
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/sort_lm_training_data.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/train_bpe_model.py b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/prepare.sh b/egs/xbmu_amdo31/ASR/prepare.sh
new file mode 100755
index 000000000..32ae440f7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/prepare.sh
@@ -0,0 +1,357 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+nj=15
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/xbmu_amdo31
+#      You can find data, resource, etc, inside it.
+#      You can download them from https://huggingface.co/datasets/syzym/xbmu_amdo31
+#
+#  - $dl_dir/lm
+#      This directory contains the following files downloaded from
+#       git lfs install
+#       https://huggingface.co/syzym/xbmu_amdo31_lm
+#
+#        - tibetan.3-gram.arpa
+#        - tibetan.4-gram.arpa
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#       http://www.openslr.org/17/
+#
+#     - music
+#     - noise
+#     - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  1000
+  500
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "stage -1: Download LM"
+  # We assume that you have installed git-lfs. If not, you can install it
+  # using: `sudo apt-get install git-lfs && git-lfs install`
+  git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+
+  if [ ! -f $dl_dir/lm/tibetan.3-gram.arpa ]; then
+    git clone https://huggingface.co/syzym/xbmu_amdo31_lm $dl_dir/lm
+    pushd $dl_dir/lm
+    git lfs pull --include "tibetan.3-gram.arpa"
+    git lfs pull --include "tibetan.4-gram.arpa"
+    popd
+  fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/xbmu_amdo31,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/xbmu_amdo31 $dl_dir/xbmu_amdo31
+  #
+  
+  if [ ! -d $dl_dir/xbmu_amdo31 ]; then
+    git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+    lhotse download xbmu-amdo31 $dl_dir
+  fi
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare xbmu_amdo31 manifest"
+  # We assume that you have downloaded the xbmu_amdo31 corpus
+  # to $dl_dir/xbmu_amdo31
+  if [ ! -f data/manifests/.xbmu_amdo31_manifests.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare xbmu-amdo31 $dl_dir/xbmu_amdo31 data/manifests
+    touch data/manifests/.xbmu_amdo31_manifests.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  if [ ! -f data/manifests/.musan_manifests.done ]; then
+    log "It may take 6 minutes"
+    mkdir -p data/manifests
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan_manifests.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for xbmu_amdo31"
+  if [ ! -f data/fbank/.xbmu_amdo31.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_xbmu_amdo31.py
+    touch data/fbank/.xbmu_amdo31.done
+  fi
+fi
+
+
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  if [ ! -f data/fbank/.musan.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
+
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+
+  (echo '!SIL SIL'; echo '<SPN> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/xbmu_amdo31/resource/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  ./local/generate_unique_lexicon.py --lang-dir $lang_dir
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data to train phone based bigram P"
+      xbmu_amdo31_text=$dl_dir/xbmu_amdo31/data/transcript/transcript_clean.txt
+      xbmu_amdo31_train_uid=$dl_dir/xbmu_amdo31/data/transcript/xbmu_amdo31_train_uid
+      find $dl_dir/xbmu_amdo31/data/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '-' '{print $NF}' > $xbmu_amdo31_train_uid
+      awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $xbmu_amdo31_train_uid $xbmu_amdo31_text |
+        cut -d " " -f 2- > $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare bigram P"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
+        --oov "" \
+        > $lang_dir/transcript_tokens.txt
+    fi
+
+    if [ ! -f $lang_dir/P.arpa ]; then
+      ./shared/make_kn_lm.py \
+        -ngram-order 2 \
+        -text $lang_dir/transcript_tokens.txt \
+        -lm $lang_dir/P.arpa
+    fi
+
+    if [ ! -f $lang_dir/P.fst.txt ]; then
+      python3 -m kaldilm \
+        --read-symbol-table="$lang_dir/tokens.txt" \
+        --disambig-symbol='#0' \
+        --max-order=2 \
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
+  # We assume you have installed kaldilm. If not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $dl_dir/lm/tibetan.3-gram.arpa > data/lm/G_3_gram.fst.txt
+  fi
+
+  if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      $dl_dir/lm/tibetan.4-gram.arpa > data/lm/G_4_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile HLG"
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
+fi
+
+# Compile LG for RNN-T fast_beam_search decoding
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Compile LG"
+  ./local/compile_lg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_lg.py --lang-dir $lang_dir
+  done
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Generate LM training data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    lang_dir=data/lang_bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $dl_dir/lm/lm_train.txt \
+      --lm-archive $out_dir/lm_data.pt
+  done
+fi
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Generate LM validation data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    if [ ! -f $out_dir/valid.txt ]; then
+      files=$dl_dir/xbmu_amdo31/data/transcript/dev_text
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > $out_dir/valid.txt
+    fi
+
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $out_dir/valid.txt \
+      --lm-archive $out_dir/lm_data-valid.pt
+  done
+fi
+
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "Stage 13: Generate LM test data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    if [ ! -f $out_dir/test.txt ]; then
+      files=$dl_dir/xbmu_amdo31/data/transcript/test_text
+      cat $files | cut -d " " -f 2- > $out_dir/test.txt
+    fi
+
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $out_dir/test.txt \
+      --lm-archive $out_dir/lm_data-test.pt
+  done
+fi
+
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "Stage 14: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals the number of BPE tokens
+  # in a sentence.
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data.pt \
+      --out-lm-data $out_dir/sorted_lm_data.pt \
+      --out-statistics $out_dir/statistics.txt
+
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data-valid.pt \
+      --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+      --out-statistics $out_dir/statistics-valid.txt
+
+    ./local/sort_lm_training_data.py \
+      --in-lm-data $out_dir/lm_data-test.pt \
+      --out-lm-data $out_dir/sorted_lm_data-test.pt \
+      --out-statistics $out_dir/statistics-test.txt
+  done
+fi
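Stage 6 of prepare.sh trains a SentencePiece BPE model per vocabulary size and stores it as `data/lang_bpe_{500,1000}/bpe.model`, which the feature-extraction and decoding scripts later consume. The sketch below shows, under the assumption that stage 6 has finished and the script is run from `egs/xbmu_amdo31/ASR`, how such a model tokenizes a transcript line; the sample sentence is a made-up placeholder, not text from the corpus.

```python
# Hedged usage sketch, not part of the recipe: load the BPE model produced by
# stage 6 of prepare.sh and tokenize one line.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

line = "example transcript line"  # placeholder text
print(sp.encode(line, out_type=str))  # subword pieces
print(sp.encode(line, out_type=int))  # token ids used by the transducer
```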
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
new file mode 100644
index 000000000..55d5f4636
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -0,0 +1,408 @@
+# Copyright      2021  Piotr Żelasko
+# Copyright      2022  Xiaomi Corporation     (Author: Mingshuang Luo)
+# Copyright      2022  Northwest Minzu University     (Author: Senyan Li)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import CutConcatenate  # noqa F401 for PrecomputedFeatures
+from lhotse.dataset import (
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import AudioSamples  # noqa F401 For AudioSamples
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class Xbmu_AmdoAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler"
+            "(you might want to increase it for larger datasets).",
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. Used by sampler.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for training dataset.",
+        )
+
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
+        )
+
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help="When enabled, select noise from MUSAN and mix it"
+            "with training dataset. ",
+        )
+
+        group.add_argument(
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="AudioSamples or PrecomputedFeatures",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            # Set the value of num_frame_masks according to Lhotse's version.
+            # In different Lhotse's versions, the default of num_frame_masks is
+            # different.
+            num_frame_masks = 10
+            num_frame_masks_parameter = inspect.signature(
+                SpecAugment.__init__
+            ).parameters["num_frame_masks"]
+            if num_frame_masks_parameter.default == 1:
+                num_frame_masks = 2
+            logging.info(f"Num frame mask: {num_frame_masks}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=num_frame_masks,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
+        logging.info("About to create train dataset")
+        train = K2SpeechRecognitionDataset(
+            input_strategy=eval(self.args.input_strategy)(),
+            cut_transforms=transforms,
+            input_transforms=input_transforms,
+            return_cuts=self.args.return_cuts,
+        )
+
+        if self.args.on_the_fly_feats:
+            # NOTE: the PerturbSpeed transform should be added only if we
+            # remove it from data prep stage.
+            # Add on-the-fly speed perturbation; since originally it would
+            # have increased epoch size by 3, we will apply prob 2/3 and use
+            # 3x more epochs.
+            # Speed perturbation probably should come first before
+            # concatenation, but in principle the transforms order doesn't have
+            # to be strict (e.g. could be randomized)
+            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms   # noqa
+            # Drop feats to be on the safe side.
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
+
+        if self.args.bucketing_sampler:
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                drop_last=self.args.drop_last,
+            )
+        else:
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+            )
+        logging.info("About to create train dataloader")
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
+        train_dl = DataLoader(
+            train,
+            sampler=train_sampler,
+            batch_size=None,
+            num_workers=self.args.num_workers,
+            persistent_workers=False,
+            worker_init_fn=worker_init_fn,
+        )
+
+        return train_dl
+
+    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+        transforms = []
+        if self.args.concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.return_cuts,
+            )
+        valid_sampler = DynamicBucketingSampler(
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        valid_dl = DataLoader(
+            validate,
+            sampler=valid_sampler,
+            batch_size=None,
+            num_workers=2,
+            persistent_workers=False,
+        )
+
+        return valid_dl
+
+    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+        logging.debug("About to create test dataset")
+        test = K2SpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else eval(self.args.input_strategy)(),
+            return_cuts=self.args.return_cuts,
+        )
+        sampler = DynamicBucketingSampler(
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=False,
+        )
+        logging.debug("About to create test dataloader")
+        test_dl = DataLoader(
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
+        )
+        return test_dl
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_train.jsonl.gz"
+        logging.info(f"About to get train cuts from {f}")
+        cuts_train = load_manifest_lazy(f)
+        return cuts_train
+
+    @lru_cache()
+    def valid_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_dev.jsonl.gz"
+        logging.info(f"About to get valid cuts from {f}")
+        cuts_valid = load_manifest_lazy(f)
+        return cuts_valid
+
+    @lru_cache()
+    def test_cuts(self) -> CutSet:
+        f = self.args.manifest_dir / "xbmu_amdo31_cuts_test.jsonl.gz"
+        logging.info(f"About to get test cuts from {f}")
+        cuts_test = load_manifest_lazy(f)
+        return cuts_test
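`Xbmu_AmdoAsrDataModule` bundles the sampler, augmentation, and dataloader setup; the recipe's `train.py`/`decode.py` register its CLI options via `add_arguments()` and then request dataloaders. A minimal wiring sketch follows (not part of the patch), assuming it is run from the `pruned_transducer_stateless5` directory and that the fbank manifests from `prepare.sh` exist under `data/fbank`; the argument values are illustrative.

```python
# Hedged usage sketch: build the training dataloader from the data module.
import argparse

from asr_datamodule import Xbmu_AmdoAsrDataModule

parser = argparse.ArgumentParser()
Xbmu_AmdoAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "200", "--num-workers", "2"])

datamodule = Xbmu_AmdoAsrDataModule(args)
train_dl = datamodule.train_dataloaders(datamodule.train_cuts())

for batch in train_dl:
    # K2SpeechRecognitionDataset yields (N, T, C) features plus supervisions.
    print(batch["inputs"].shape, len(batch["supervisions"]["text"]))
    break
```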
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
new file mode 100755
index 000000000..b77f734e3
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
@@ -0,0 +1,964 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao,
+#                                                 Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+(3) modified beam search
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+(4) fast beam search (one best)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+(5) fast beam search (nbest)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+(7) fast beam search (with LG)
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(8) modified beam search with RNNLM shallow fusion
+./pruned_transducer_stateless5/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    --beam-size 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
+
+
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+    modified_beam_search_rnnlm_shallow_fusion,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.rnn_lm.model import RnnLmModel
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_LG
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=20.0,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is fast_beam_search_LG,
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is fast_beam_search_LG,
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding-method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the path to RNN LM exp dir.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the path to RNN LM exp dir.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-epoch",
+        type=int,
+        default=7,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the checkpoint to use.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-avg",
+        type=int,
+        default=2,
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
+        It specifies the number of checkpoints to average.
+        """,
+    )
+
+    parser.add_argument(
+        "--rnn-lm-embedding-dim",
+        type=int,
+        default=2048,
+        help="Embedding dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-hidden-dim",
+        type=int,
+        default=2048,
+        help="Hidden dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-num-layers",
+        type=int,
+        default=4,
+        help="Number of RNN layers the model",
+    )
+    parser.add_argument(
+        "--rnn-lm-tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+    rnnlm: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search"
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7"
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an LG graph.
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle, or
+        fast_beam_search_nbest_LG.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
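+        # Pad the time axis with `left_context` extra frames of LOG_EPS
+        # (and increase the lengths accordingly) before running the
+        # simulated streaming encoder.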
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    hyps = []
+
+    if (
+        params.decoding_method == "fast_beam_search"
+        or params.decoding_method == "fast_beam_search_LG"
+    ):
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
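+        # For fast_beam_search the hypotheses are BPE token IDs, which we
+        # decode with the BPE model; for fast_beam_search_LG they are word
+        # IDs, which we look up in the word table.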
+        if params.decoding_method == "fast_beam_search":
+            for hyp in sp.decode(hyp_tokens):
+                hyps.append(hyp.split())
+        else:
+            for hyp in hyp_tokens:
+                hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            sp=sp,
+            rnnlm=rnnlm,
+            rnnlm_scale=rnnlm_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
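+        # greedy_search with --max-sym-per-frame > 1 and beam_search decode
+        # one utterance at a time.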
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+        if "LG" in params.decoding_method:
+            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+    rnnlm: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an LG graph.
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle, or
+        fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
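+    # greedy_search is fast, so we log progress less frequently for it.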
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        logging.info(f"Decoding {batch_idx}-th batch")
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+            rnnlm=rnnlm,
+            rnnlm_scale=rnnlm_scale,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_LG",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+        "modified_beam_search_rnnlm_shallow_fusion",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
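+    # Build a suffix for the result files that records the checkpoint and
+    # decoding configuration used.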
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+        if "LG" in params.decoding_method:
+            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
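+            # The returned checkpoints are ordered from latest to earliest,
+            # so the last entry is the start (excluded) of the averaging
+            # range and the first entry is the end.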
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    rnn_lm_model = None
+    rnn_lm_scale = params.rnn_lm_scale
+    if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
+        rnn_lm_model = RnnLmModel(
+            vocab_size=params.vocab_size,
+            embedding_dim=params.rnn_lm_embedding_dim,
+            hidden_dim=params.rnn_lm_hidden_dim,
+            num_layers=params.rnn_lm_num_layers,
+            tie_weights=params.rnn_lm_tie_weights,
+        )
+        assert params.rnn_lm_avg == 1
+
+        load_checkpoint(
+            f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+            rnn_lm_model,
+        )
+        rnn_lm_model.to(device)
+        rnn_lm_model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if "LG" in params.decoding_method:
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
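+            # Scale the LG graph scores by the n-gram LM weight before
+            # using it as the decoding graph.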
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    test_cuts = xbmu_amdo.test_cuts()
+
+    test_dl = xbmu_amdo.test_dataloaders(test_cuts)
+
+    test_sets = ["test"]
+    test_dl = [test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+            rnnlm=rnn_lm_model,
+            rnnlm_scale=rnn_lm_scale,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py
new file mode 120000
index 000000000..d59ef95f7
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py
new file mode 120000
index 000000000..722e1c894
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py
new file mode 100755
index 000000000..54f656859
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py
@@ -0,0 +1,287 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless5/export.py \
+  --exp-dir ./pruned_transducer_stateless5/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/xbmu_amdo31/ASR
+    ./pruned_transducer_stateless5/decode.py \
+        --exp-dir ./pruned_transducer_stateless5/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model, if the models in exp-dir
+        are streaming model, this should be True.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.streaming_model:
+        assert params.causal_convolution
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        convert_scaled_to_non_scaled(model, inplace=True)
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py
new file mode 120000
index 000000000..9052f3cbb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py
new file mode 120000
index 000000000..a99e74334
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
new file mode 100755
index 000000000..74a2210c3
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless5/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
+./pruned_transducer_stateless5/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
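+    # Pad all utterances to the length of the longest one; math.log(1e-10)
+    # is a very small log-mel value used as the padding value.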
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
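+        # greedy_search with --max-sym-per-frame > 1 and beam_search decode
+        # one utterance at a time.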
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py
new file mode 120000
index 000000000..1199a61d6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py
new file mode 120000
index 000000000..f29284163
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py
new file mode 100755
index 000000000..9aad32014
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/xbmu_amdo31/ASR
+    python ./pruned_transducer_stateless5/test_model.py
+"""
+
+from train import get_params, get_transducer_model
+
+
+def test_model_1():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = 24
+    params.dim_feedforward = 1536  # 384 * 4
+    params.encoder_dim = 384
+    model = get_transducer_model(params)
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
+def test_model_M():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = 18
+    params.dim_feedforward = 1024
+    params.encoder_dim = 256
+    params.nhead = 4
+    params.decoder_dim = 512
+    params.joiner_dim = 512
+    model = get_transducer_model(params)
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+def main():
+    #  test_model_1()
+    test_model_M()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py
new file mode 100755
index 000000000..5b5ac17be
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py
@@ -0,0 +1,1187 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)
+
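+# Type alias covering both native PyTorch LR schedulers and the schedulers
+# defined in ./optim.py (e.g., Eden).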
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=24,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=1536,
+        help="Feedforward dimension of the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=384,
+        help="Attention dimension in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training, if you want a streaming
+        model, this requires to be True.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution, this requires to be True when
+        using dynamic_chunk_training.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length of dynamic training, the chunk size would be either
+        max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="How many left context can be seen in chunks when calculating attention.",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless5/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="The initial learning rate.  This value should not need to be changed.",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
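+
+    # Illustrative note: with context_size=2 the stateless decoder conditions
+    # each prediction only on the previous two non-blank symbols, so it acts
+    # roughly like a trigram LM over tokens.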
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
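+
+    # Illustrative note: with the default prune_range=5, the pruned RNN-T loss
+    # is computed over a band of 5 symbol positions per frame around the
+    # alignment estimated from the simple joiner (which is just an addition).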
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=4000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
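+
+    # For example, with the default average_period=100, at batch_idx_train=400
+    # the update above becomes
+    # model_avg = model * (100 / 400) + model_avg * (300 / 400),
+    # i.e. model_avg tracks a running average of the parameters from the
+    # start of training.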
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
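+
+    # Note: in compute_loss() below, delay_penalty is only applied once
+    # warmup >= 2.0, i.e. after 2 * model_warm_step training batches.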
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            # parameters for Noam
+            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Conformer and Transformer
+    encoder = Conformer(
+        num_features=params.feature_dim,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.encoder_dim,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        dynamic_chunk_training=params.dynamic_chunk_training,
+        short_chunk_size=params.short_chunk_size,
+        num_left_chunks=params.num_left_chunks,
+        causal=params.causal_convolution,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute RNN-T loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      warmup:
+        A floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+            warmup=warmup,
+            reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
+        )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If all simple_loss values are inf/nan or all pruned_loss values
+            # are inf/nan, we stop the training process by raising an
+            # exception.
+            if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
+        # after the main warmup step, we keep pruned_loss_scale small
+        # for the same amount of time (model_warm_step), to avoid
+        # overwhelming the simple_loss and causing it to diverge,
+        # in case it had not fully learned the alignment yet.
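+        # Concretely: warmup < 1.0 -> scale 0.0; 1.0 < warmup < 2.0 -> scale
+        # 0.1; otherwise -> scale 1.0.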
+        pruned_loss_scale = (
+            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler; we call step_batch() for every batch.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
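+            # tot_loss decays old statistics by (1 - 1/reset_interval) each
+            # batch, so it approximates a moving average over roughly the
+            # last `reset_interval` (default 200) batches.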
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.dynamic_chunk_training:
+        assert (
+            params.causal_convolution
+        ), "dynamic_chunk_training requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    train_cuts = xbmu_amdo.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./conformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 1) // 2 - 1) // 2
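+        # For example, an utterance with num_frames = 1000 gives
+        # T = ((1000 - 1) // 2 - 1) // 2 = 249 frames after subsampling.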
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = xbmu_amdo.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = xbmu_amdo.valid_cuts()
+    valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts)
+
+    if params.start_batch <= 0 and not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+    warmup: float,
+):
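+    """Check the most memory-demanding ("pessimistic") batches found by the
+    sampler with a forward/backward pass, so that CUDA OOM errors surface
+    before the real training starts.
+    """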
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=warmup,
+                )
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py
new file mode 120000
index 000000000..c473a600a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless5/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..e334e690a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,837 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method beam_search \
+    --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless7/exp \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from beam_search import (
+    beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+          - fast_beam_search_nbest
+          - fast_beam_search_nbest_oracle
+          - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to specify
+        `--lang-dir`, which should contain `LG.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=20.0,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=64,
+        help="""Used only when --decoding-method is
+        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+        and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding, this is a good way to
+        test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search".
+               If beam search with a beam size of 7 is used, it would be
+               "beam_7".
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    """
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if a beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript (a list of words), and the
+      predicted result (also a list of words).
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    if params.decoding_method == "greedy_search":
+        log_interval = 50
+    else:
+        log_interval = 20
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
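+            # find_checkpoints() returns the checkpoints sorted from the latest
+            # to the earliest, so the last entry marks the start of the
+            # averaging interval and the first entry marks its end.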
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
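+            # For plain fast_beam_search, decode with a trivial graph that
+            # places no constraint on the token sequence.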
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    test_cuts = xbmu_amdo.test_cuts()
+
+    test_dl = xbmu_amdo.test_dataloaders(test_cuts)
+
+    test_sets = [
+        "test",
+    ]
+    test_dl = [
+        test_dl,
+    ]
+
+    for test_set, test_dl in zip(test_sets, test_dl):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py
new file mode 120000
index 000000000..2713792e6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/export.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py
new file mode 120000
index 000000000..a44034e34
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
new file mode 100755
index 000000000..d05bafcfb
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
@@ -0,0 +1,355 @@
+#!/usr/bin/env python3
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./pruned_transducer_stateless7/export.py \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+Usage of this script:
+
+(1) greedy search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method greedy_search \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless7/pretrained.py \
+    --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method fast_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
+./pruned_transducer_stateless7/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
new file mode 120000
index 000000000..7ceac5d10
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..1332bafd8
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1224 @@
+#!/usr/bin/env python3
+# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --full-libri 1 \
+  --max-duration 300
+
+# For mix precision training:
+
+./pruned_transducer_stateless7/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless7/exp \
+  --full-libri 1 \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import Xbmu_AmdoAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+    save_checkpoint_with_global_batch_idx,
+    update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
+        "Must be <= each of encoder_dims.  Empirically, less than 256 seems to make performance "
+        " worse.",
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.05, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version"
+        "loss(joiner is just addition), this simple loss also uses for"
+        "training (as a regularization item). We will scale the simple loss"
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+       The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    Returns:
+      Return a tuple of two elements: the loss tensor and a MetricsTracker
+      instance containing statistics such as the number of frames and the
+      individual loss values.
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    batch_idx_train = params.batch_idx_train
+    warm_step = params.warm_step
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+        simple_loss_scale = (
+            s
+            if batch_idx_train >= warm_step
+            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+        )
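+        # Ramp the scale on the pruned loss from 0.1 up to 1.0 over the
+        # warm-up period.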
+        pruned_loss_scale = (
+            1.0
+            if batch_idx_train >= warm_step
+            else 0.1 + 0.9 * (batch_idx_train / warm_step)
+        )
+
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    sp: spm.SentencePieceProcessor,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
+    for batch_idx, batch in enumerate(train_dl):
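+        # When resuming from a checkpoint saved in the middle of an epoch,
+        # skip the batches that have already been processed.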
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            # summary stats
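+            # tot_loss is an exponentially decaying running sum: earlier batches
+            # are down-weighted by a factor of (1 - 1/reset_interval) per step.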
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            set_batch_count(model, params.batch_idx_train)
+            scheduler.step_batch(params.batch_idx_train)
+
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+
+        if params.print_diagnostics and batch_idx == 5:
+            return
+
+        if (
+            rank == 0
+            and params.batch_idx_train > 0
+            and params.batch_idx_train % params.average_period == 0
+        ):
+            update_averaged_model(
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
+            )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                model_avg=model_avg,
+                params=params,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
+        if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have
+            # different behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+            if cur_grad_scale < 1.0e-05:
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}, "
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+    train_cuts = xbmu_amdo.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = xbmu_amdo.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = xbmu_amdo.valid_cuts()
+    valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for its content.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+    model: Union[nn.Module, DDP],
+    train_dl: torch.utils.data.DataLoader,
+    optimizer: torch.optim.Optimizer,
+    sp: spm.SentencePieceProcessor,
+    params: AttributeDict,
+):
+    from lhotse.dataset import find_pessimistic_batches
+
+    logging.info(
+        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+    )
+    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    for criterion, cuts in batches.items():
+        batch = train_dl.dataset[cuts]
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
+            loss.backward()
+            optimizer.zero_grad()
+        except Exception as e:
+            if "CUDA out of memory" in str(e):
+                logging.error(
+                    "Your GPU ran out of memory with the current "
+                    "max_duration setting. We recommend decreasing "
+                    "max_duration and trying again.\n"
+                    f"Failing criterion: {criterion} "
+                    f"(={crit_values[criterion]}) ..."
+                )
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise
+        logging.info(
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )
+
+
+def main():
+    parser = get_parser()
+    Xbmu_AmdoAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
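
The filter `remove_short_and_long_utt` above enforces the pruned RNN-T constraint T >= S. Below is a minimal, standalone sketch of that check; the frame count and BPE pieces are made-up illustrative values, and the helper mirrors the subsampling expression used in the training script above.

```python
# Standalone sketch of the T >= S check in remove_short_and_long_utt above.
# The frame count and BPE pieces below are illustrative, not real data.

def frames_after_subsampling(num_frames: int) -> int:
    # Same expression as in the training script: drop 7 frames, then two
    # stride-2 reductions performed by the convolutional subsampling module.
    return ((num_frames - 7) // 2 + 1) // 2

num_frames = 500                        # e.g. a 5-second cut at 100 frames/s
tokens = ["▁HE", "LL", "O", "▁WORLD"]   # hypothetical BPE pieces for the cut
T = frames_after_subsampling(num_frames)  # -> 123
S = len(tokens)                           # -> 4
print(T, S, T >= S)                       # keep the cut only if T >= S
```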
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/shared b/egs/xbmu_amdo31/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/yesno/ASR/README.md b/egs/yesno/ASR/README.md
index 7257bad9a..38b491fc6 100644
--- a/egs/yesno/ASR/README.md
+++ b/egs/yesno/ASR/README.md
@@ -10,5 +10,5 @@ get the following WER:
 ```
 
 Please refer to
-
+
 for detailed instructions.
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index f83be05cf..e0a94bf08 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -78,10 +78,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 
     logging.info("Removing disambiguation symbols on LG")
 
-    LG.labels[LG.labels >= first_token_disambig_id] = 0
-    # See https://github.com/k2-fsa/k2/issues/874
-    # for why we need to set LG.properties to None
-    LG.__dict__["_properties"] = None
+    # LG.labels[LG.labels >= first_token_disambig_id] = 0
+    # see https://github.com/k2-fsa/k2/pull/1140
+    labels = LG.labels
+    labels[labels >= first_token_disambig_id] = 0
+    LG.labels = labels
 
     assert isinstance(LG.aux_labels, k2.RaggedTensor)
     LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
@@ -128,9 +129,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py
index 9a4e8a36f..75d95df68 100755
--- a/egs/yesno/ASR/local/compute_fbank_yesno.py
+++ b/egs/yesno/ASR/local/compute_fbank_yesno.py
@@ -54,9 +54,7 @@ def compute_fbank_yesno():
         dataset_parts,
     )
 
-    extractor = Fbank(
-        FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)
-    )
+    extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -71,9 +69,7 @@ def compute_fbank_yesno():
             )
             if "train" in partition:
                 cut_set = (
-                    cut_set
-                    + cut_set.perturb_speed(0.9)
-                    + cut_set.perturb_speed(1.1)
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
@@ -87,9 +83,7 @@ def compute_fbank_yesno():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
 
diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh
index 8fcee0290..d4ef8d601 100755
--- a/egs/yesno/ASR/prepare.sh
+++ b/egs/yesno/ASR/prepare.sh
@@ -1,5 +1,8 @@
 #!/usr/bin/env bash
 
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
 set -eou pipefail
 
 stage=-1
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index 85e5f1358..3c1682fa1 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -121,7 +121,7 @@ class YesNoAsrDataModule(DataModule):
         group.add_argument(
             "--shuffle",
             type=str2bool,
-            default=True,
+            default=False,
             help="When enabled (=default), the examples will be "
             "shuffled for each epoch.",
         )
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 9d4ab4b61..d5efb41df 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -201,9 +201,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -274,9 +272,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
@@ -297,9 +293,7 @@ def main():
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
+        torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
         return
 
     model.to(device)
@@ -317,9 +311,7 @@ def main():
         word_table=lexicon.word_table,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 14220be19..65be77db1 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -53,9 +53,7 @@ def get_parser():
         help="Path to words.txt",
     )
 
-    parser.add_argument(
-        "--HLG", type=str, required=True, help="Path to HLG.pt."
-    )
+    parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
 
     parser.add_argument(
         "sound_files",
@@ -101,10 +99,9 @@ def read_sound_files(
     ans = []
     for f in filenames:
         wave, sample_rate = torchaudio.load(f)
-        assert sample_rate == expected_sample_rate, (
-            f"expected sample rate: {expected_sample_rate}. "
-            f"Given: {sample_rate}"
-        )
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
         ans.append(wave[0])
     return ans
@@ -159,9 +156,7 @@ def main():
     logging.info("Decoding started")
     features = fbank(waves)
 
-    features = pad_sequence(
-        features, batch_first=True, padding_value=math.log(1e-10)
-    )
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
 
     # Note: We don't use key padding mask for attention during decoding
     with torch.no_grad():
@@ -201,9 +196,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index f32a27f35..335493491 100755
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
index 6714180db..7f13e417a 100755
--- a/egs/yesno/ASR/transducer/decode.py
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -116,9 +116,7 @@ def decode_one_batch(
     # at entry, feature is (N, T, C)
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
 
     hyps = []
     batch_size = encoder_out.size(0)
@@ -186,9 +184,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results
 
 
@@ -303,9 +299,7 @@ def main():
         model=model,
     )
 
-    save_results(
-        exp_dir=params.exp_dir, test_set_name="test_set", results=results
-    )
+    save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results)
 
     logging.info("Done!")
 
diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py
index deb92107d..88866ae81 100755
--- a/egs/yesno/ASR/transducer/train.py
+++ b/egs/yesno/ASR/transducer/train.py
@@ -430,9 +430,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
diff --git a/icefall/__init__.py b/icefall/__init__.py
index 27ad74213..5d846b41d 100644
--- a/icefall/__init__.py
+++ b/icefall/__init__.py
@@ -8,6 +8,12 @@ from . import (
     utils
 )
 
+from .byte_utils import (
+    byte_decode,
+    byte_encode,
+    smart_byte_decode,
+)
+
 from .checkpoint import (
     average_checkpoints,
     find_checkpoints,
@@ -49,6 +55,7 @@ from .utils import (
     get_alignments,
     get_executor,
     get_texts,
+    is_cjk,
     is_jit_tracing,
     is_module_available,
     l1_norm,
@@ -64,7 +71,10 @@ from .utils import (
     store_transcripts,
     str2bool,
     subsequent_chunk_mask,
+    tokenize_by_CJK_char,
     write_error_stats,
 )
 
 from .ngram_lm import NgramLm, NgramLmStateCost
+
+from .lm_wrapper import LmScorer
diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py
index e76b7ea32..d9659c2dd 100644
--- a/icefall/bpe_graph_compiler.py
+++ b/icefall/bpe_graph_compiler.py
@@ -83,11 +83,12 @@ class BpeCtcTrainingGraphCompiler(object):
         Args:
           piece_ids:
             It is a list-of-list integer IDs.
-         modified:
+          modified:
            See :func:`k2.ctc_graph` for its meaning.
         Return:
           Return an FsaVec, which is the result of composing a
           CTC topology with linear FSAs constructed from the given
           piece IDs.
         """
-        return k2.ctc_graph(piece_ids, modified=modified, device=self.device)
+        graph = k2.ctc_graph(piece_ids, modified=modified, device=self.device)
+        return graph
diff --git a/icefall/byte_utils.py b/icefall/byte_utils.py
new file mode 100644
index 000000000..7ee84ad27
--- /dev/null
+++ b/icefall/byte_utils.py
@@ -0,0 +1,311 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py
+
+import re
+import unicodedata
+
+
+WHITESPACE_NORMALIZER = re.compile(r"\s+")
+SPACE = chr(32)
+SPACE_ESCAPE = chr(9601)
+
+PRINTABLE_BASE_CHARS = [
+    256,
+    257,
+    258,
+    259,
+    260,
+    261,
+    262,
+    263,
+    264,
+    265,
+    266,
+    267,
+    268,
+    269,
+    270,
+    271,
+    272,
+    273,
+    274,
+    275,
+    276,
+    277,
+    278,
+    279,
+    280,
+    281,
+    282,
+    283,
+    284,
+    285,
+    286,
+    287,
+    32,
+    33,
+    34,
+    35,
+    36,
+    37,
+    38,
+    39,
+    40,
+    41,
+    42,
+    43,
+    44,
+    45,
+    46,
+    47,
+    48,
+    49,
+    50,
+    51,
+    52,
+    53,
+    54,
+    55,
+    56,
+    57,
+    58,
+    59,
+    60,
+    61,
+    62,
+    63,
+    64,
+    65,
+    66,
+    67,
+    68,
+    69,
+    70,
+    71,
+    72,
+    73,
+    74,
+    75,
+    76,
+    77,
+    78,
+    79,
+    80,
+    81,
+    82,
+    83,
+    84,
+    85,
+    86,
+    87,
+    88,
+    89,
+    90,
+    91,
+    92,
+    93,
+    94,
+    95,
+    96,
+    97,
+    98,
+    99,
+    100,
+    101,
+    102,
+    103,
+    104,
+    105,
+    106,
+    107,
+    108,
+    109,
+    110,
+    111,
+    112,
+    113,
+    114,
+    115,
+    116,
+    117,
+    118,
+    119,
+    120,
+    121,
+    122,
+    123,
+    124,
+    125,
+    126,
+    288,
+    289,
+    290,
+    291,
+    292,
+    293,
+    294,
+    295,
+    296,
+    297,
+    298,
+    299,
+    300,
+    301,
+    302,
+    303,
+    304,
+    305,
+    308,
+    309,
+    310,
+    311,
+    312,
+    313,
+    314,
+    315,
+    316,
+    317,
+    318,
+    321,
+    322,
+    323,
+    324,
+    325,
+    326,
+    327,
+    328,
+    330,
+    331,
+    332,
+    333,
+    334,
+    335,
+    336,
+    337,
+    338,
+    339,
+    340,
+    341,
+    342,
+    343,
+    344,
+    345,
+    346,
+    347,
+    348,
+    349,
+    350,
+    351,
+    352,
+    353,
+    354,
+    355,
+    356,
+    357,
+    358,
+    359,
+    360,
+    361,
+    362,
+    363,
+    364,
+    365,
+    366,
+    367,
+    368,
+    369,
+    370,
+    371,
+    372,
+    373,
+    374,
+    375,
+    376,
+    377,
+    378,
+    379,
+    380,
+    381,
+    382,
+    384,
+    385,
+    386,
+    387,
+    388,
+    389,
+    390,
+    391,
+    392,
+    393,
+    394,
+    395,
+    396,
+    397,
+    398,
+    399,
+    400,
+    401,
+    402,
+    403,
+    404,
+    405,
+    406,
+    407,
+    408,
+    409,
+    410,
+    411,
+    412,
+    413,
+    414,
+    415,
+    416,
+    417,
+    418,
+    419,
+    420,
+    421,
+    422,
+]
+
+for c in PRINTABLE_BASE_CHARS:
+    assert unicodedata.normalize("NFKC", chr(c)) == chr(c), c
+
+BYTE_TO_BCHAR = {b: chr(PRINTABLE_BASE_CHARS[b]) for b in range(256)}
+BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
+
+
+def byte_encode(x: str) -> str:
+    normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
+    return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
+
+
+def byte_decode(x: str) -> str:
+    try:
+        return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
+    except ValueError:
+        return ""
+
+
+def smart_byte_decode(x: str) -> str:
+    output = byte_decode(x)
+    if output == "":
+        # Use dynamic programming to recover as many valid characters as
+        # possible when the byte sequence cannot be decoded as a whole.
+        n_bytes = len(x)
+        f = [0 for _ in range(n_bytes + 1)]
+        pt = [0 for _ in range(n_bytes + 1)]
+        for i in range(1, n_bytes + 1):
+            f[i], pt[i] = f[i - 1], i - 1
+            for j in range(1, min(4, i) + 1):
+                if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
+                    f[i], pt[i] = f[i - j] + 1, i - j
+        cur_pt = n_bytes
+        while cur_pt > 0:
+            if f[cur_pt] == f[pt[cur_pt]] + 1:
+                output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
+            cur_pt = pt[cur_pt]
+    return output
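
A small usage sketch for the byte-level helpers added above, assuming `icefall` is importable. It shows the lossless round trip and how `smart_byte_decode` recovers the decodable part when the encoded string is truncated in the middle of a multi-byte character.

```python
from icefall.byte_utils import byte_decode, byte_encode, smart_byte_decode

text = "hello 世界"
encoded = byte_encode(text)           # each UTF-8 byte becomes one printable char
assert byte_decode(encoded) == text   # lossless round trip

# Dropping the last char breaks the final multi-byte character, so the plain
# decoder gives up while the DP-based decoder keeps the decodable prefix.
truncated = encoded[:-1]
print(repr(byte_decode(truncated)))        # ''
print(repr(smart_byte_decode(truncated)))  # 'hello 世'
```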
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index 235160e14..5f9571d42 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -71,9 +71,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = re.sub(whitespace, "", text)
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -96,9 +94,7 @@ class CharCtcTrainingGraphCompiler(object):
         for text in texts:
             text = text.split("/")
             sub_ids = [
-                self.token_table[txt]
-                if txt in self.token_table
-                else self.oov_id
+                self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
@@ -121,4 +117,5 @@ class CharCtcTrainingGraphCompiler(object):
           CTC topology with linear FSAs constructed from the given
           piece IDs.
         """
-        return k2.ctc_graph(token_ids, modified=modified, device=self.device)
+        graph = k2.ctc_graph(token_ids, modified=modified, device=self.device)
+        return graph
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 5069b78e8..c83c56a53 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -292,15 +292,19 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    iter_checkpoints = [
-        (int(pattern.search(c).group(1)), c) for c in checkpoints
-    ]
+    iter_checkpoints = []
+    for c in checkpoints:
+        result = pattern.search(c)
+        if not result:
+            logging.warning(f"Invalid checkpoint filename {c}")
+            continue
+
+        iter_checkpoints.append((int(result.group(1)), c))
+
     # iter_checkpoints is a list of tuples. Each tuple contains
     # two elements: (iteration_number, checkpoint-iteration_number.pt)
 
-    iter_checkpoints = sorted(
-        iter_checkpoints, reverse=True, key=lambda x: x[0]
-    )
+    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
     if iteration >= 0:
         ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
     else:
@@ -469,7 +473,5 @@ def average_state_dict(
         v = state_dict_1[k]
         if torch.is_floating_point(v):
             v *= weight_1
-            v += (
-                state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-            )
+            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
             v *= scaling_factor
diff --git a/icefall/decode.py b/icefall/decode.py
index f04ee368c..23f9fb9b3 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,13 +334,9 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
-                word_fsa
-            )
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
 
         path_to_utt_map = self.shape.row_ids(1)
 
@@ -370,9 +366,7 @@ class Nbest(object):
         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-        one_best = k2.shortest_path(
-            path_lattice, use_double_scores=use_double_scores
-        )
+        one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
 
         one_best = k2.invert(one_best)
         # Now one_best has token IDs as labels and word IDs as aux_labels
@@ -442,9 +436,7 @@ class Nbest(object):
         scores_shape = self.fsa.arcs.shape().remove_axis(1)
         # scores_shape has axes [path][arc]
 
-        ragged_scores = k2.RaggedTensor(
-            scores_shape, self.fsa.scores.contiguous()
-        )
+        ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous())
 
         tot_scores = ragged_scores.sum()
 
@@ -459,7 +451,8 @@ class Nbest(object):
 def one_best_decoding(
     lattice: k2.Fsa,
     use_double_scores: bool = True,
-) -> k2.Fsa:
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Get the best path from a lattice.
 
     Args:
@@ -468,11 +461,24 @@ def one_best_decoding(
       use_double_scores:
         True to use double precision floating point in the computation.
         False to use single precision.
+      lm_scale_list:
+        A list of floats representing LM score scales.
     Return:
        An FsaVec containing linear paths if lm_scale_list is None; otherwise,
        a dict mapping keys of the form "lm_scale_{x}" to such FsaVecs.
     """
-    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
-    return best_path
+    if lm_scale_list is not None:
+        ans = dict()
+        saved_am_scores = lattice.scores - lattice.lm_scores
+        for lm_scale in lm_scale_list:
+            am_scores = saved_am_scores / lm_scale
+            lattice.scores = am_scores + lattice.lm_scores
+
+            best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
+            key = f"lm_scale_{lm_scale}"
+            ans[key] = best_path
+        return ans
+
+    return k2.shortest_path(lattice, use_double_scores=use_double_scores)
 
 
 def nbest_decoding(
@@ -678,9 +684,7 @@ def rescore_with_n_best_list(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
@@ -711,6 +715,107 @@ def rescore_with_n_best_list(
     return ans
 
 
+def nbest_rescore_with_LM(
+    lattice: k2.Fsa,
+    LM: k2.Fsa,
+    num_paths: int,
+    lm_scale_list: List[float],
+    nbest_scale: float = 1.0,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """Rescore an n-best list with an n-gram LM.
+    The path with the maximum score is used as the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc]. It must have the following
+        attributes: ``aux_labels`` and ``lm_scores``. They are both token
+        IDs.
+      LM:
+        An FsaVec containing only a single FSA. It is one of the following:
+        - LG, where L is the lexicon and G is a word-level n-gram LM.
+        - G, a token-level n-gram LM.
+      num_paths:
+        Size of nbest list.
+      lm_scale_list:
+        A list of floats representing LM score scales.
+      nbest_scale:
+        Scale to be applied to ``lattice.scores`` when sampling paths
+        using ``k2.random_paths``.
+      use_double_scores:
+        True to use double precision during computation. False to use
+        single precision.
+    Returns:
+      A dict of FsaVec: each key is an lm_scale and the corresponding value
+      contains the best decoding path for each utterance in the lattice.
+    """
+    device = lattice.device
+
+    assert len(lattice.shape) == 3
+    assert hasattr(lattice, "aux_labels")
+    assert hasattr(lattice, "lm_scores")
+
+    assert LM.shape == (1, None, None)
+    assert LM.device == device
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+    # nbest.fsa.scores contains 0s
+
+    nbest = nbest.intersect(lattice)
+
+    # Now nbest.fsa has its scores set
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    # am scores + bi-gram scores
+    hp_scores = nbest.tot_scores()
+
+    # Now start to intersect nbest with LG or G
+    inv_fsa = k2.invert(nbest.fsa)
+    if hasattr(LM, "aux_labels"):
+        # LM is LG here
+        # delete the token IDs as they are not needed
+        del inv_fsa.aux_labels
+    inv_fsa.scores.zero_()
+    inv_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(inv_fsa)
+    path_to_utt_map = nbest.shape.row_ids(1)
+
+    LM = k2.arc_sort(LM)
+    path_lattice = k2.intersect_device(
+        LM,
+        inv_fsa_with_epsilon_loops,
+        b_to_a_map=torch.zeros_like(path_to_utt_map),
+        sorted_match_a=True,
+    )
+
+    # Its labels are token IDs.
+    # If LM is G, its aux_labels are token IDs;
+    # If LM is LG, its aux_labels are word IDs.
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
+    one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores)
+
+    lm_scores = one_best.get_tot_scores(
+        use_double_scores=use_double_scores,
+        log_semiring=True,  # Note: we always use True
+    )
+    # If LM is LG, we might get empty paths
+    lm_scores[lm_scores == float("-inf")] = -1e9
+
+    ans = dict()
+    for lm_scale in lm_scale_list:
+        tot_scores = hp_scores.values / lm_scale + lm_scores
+        tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        max_indexes = tot_scores.argmax()
+        best_path = k2.index_fsa(nbest.fsa, max_indexes)
+        key = f"lm_scale_{lm_scale}"
+        ans[key] = best_path
+    return ans
+
+
 def rescore_with_whole_lattice(
     lattice: k2.Fsa,
     G_with_epsilon_loops: k2.Fsa,
@@ -787,13 +892,9 @@ def rescore_with_whole_lattice(
         except RuntimeError as e:
             logging.info(f"Caught exception:\n{e}\n")
             if loop_count >= max_loop_count:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
-            logging.info(
-                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}")
             logging.info(
                 "This OOM is not an error. You can ignore it. "
                 "If your model does not converge well, or --max-duration "
@@ -805,9 +906,7 @@ def rescore_with_whole_lattice(
                 prune_th_list[loop_count],
                 True,
             )
-            logging.info(
-                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
-            )
+            logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}")
         loop_count += 1
 
     # lat has token IDs as labels
@@ -894,9 +993,7 @@ def rescore_with_attention_decoder(
             logging.info(f"num_paths before decreasing: {num_paths}")
             num_paths = int(num_paths / 2)
             if loop_count >= max_loop_count or num_paths <= 0:
-                logging.info(
-                    "Return None as the resulting lattice is too large."
-                )
+                logging.info("Return None as the resulting lattice is too large.")
                 return None
             logging.info(
                 "This OOM is not an error. You can ignore it. "
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index b075aceac..6589579d1 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -19,7 +19,7 @@
 
 import random
 from dataclasses import dataclass
-from typing import Optional, Tuple, List
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor, nn
@@ -78,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x ** 2
+        x = x**2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in [ "value", "max", "min" ]
+        assert stats_type in ["value", "max", "min"]
 
     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,7 +121,9 @@ class TensorDiagnostic(object):
         self.name = name
         self.class_name = None  # will assign in accumulate()
 
-        self.stats = None  # we'll later assign a list to this data member.  It's a list of dict.
+        self.stats = (
+            None  # we'll later assign a list to this data member.  It's a list of dict.
+        )
 
         # the keys into self.stats[dim] are strings, whose values can be
         # "abs", "max", "min" ,"value", "positive", "rms", "value".
@@ -133,7 +135,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
 
-
     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -185,17 +186,12 @@ class TensorDiagnostic(object):
                         done = True
                         break
                 if not done:
-                    if (
-                        this_dim_stats[stats_type] != []
-                        and stats_type == "eigs"
-                    ):
+                    if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                         # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                         # don't accumulate "eigs" stats type, it uses too much memory
                         this_dim_stats[stats_type] = None
                     else:
-                        this_dim_stats[stats_type].append(
-                            TensorAndCount(stats, count)
-                        )
+                        this_dim_stats[stats_type].append(TensorAndCount(stats, count))
 
     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -211,7 +207,6 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
 
-
                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count
 
@@ -229,9 +224,7 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print(
-                            "Error getting eigenvalues, trying another method."
-                        )
+                        print("Error getting eigenvalues, trying another method.")
                         eigs, _ = torch.eig(stats)
                         stats = eigs.abs().sqrt()
                         # sqrt so it reflects data magnitude, like stddev- not variance
@@ -242,9 +235,9 @@ class TensorDiagnostic(object):
 
                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (
-                    len(stats_list) > 1
-                ) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
+                    stats.numel()
+                )
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -261,33 +254,32 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                     ans = "[" + " ".join(ans) + "]"
-                if stats_type in [ "value", "rms", "eigs" ]:
+                if stats_type in ["value", "rms", "eigs"]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats ** 2).sum().sqrt().item()
+                    norm = (stats**2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats ** 2).mean().sqrt().item()
-                ans += f", mean={mean:.3g}, rms={rms:.3g}"
+                rms = (stats**2).mean().sqrt().item()
+                ans += f", mean={mean:.2g}, rms={rms:.2g}"
 
                 # OK, "ans" contains the actual stats, e.g.
                 # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5"
 
                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}"
-                    if len(sizes) == 1
-                    else f"{min(sizes)}..{max(sizes)}"
+                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
+                )
+                maybe_class_name = (
+                    f" type={self.class_name}," if self.class_name is not None else ""
                 )
-                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
                 print(
                     f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                 )
 
 
-
 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.
 
@@ -345,32 +337,32 @@ def attach_diagnostics(
         # (matters for name, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
 
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output,
-                                                                class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.output"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                         class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
-        def backward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
             if isinstance(_output, Tensor):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
-                                                              class_name=type(_module).__name__)
+                _model_diagnostic[f"{_name}.grad"].accumulate(
+                    _output, class_name=type(_module).__name__
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                       class_name=type(_module).__name__)
+                    _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                        o, class_name=type(_module).__name__
+                    )
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
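
For reference, a minimal sketch of how the diagnostics utilities reformatted above are wired into a model; it mirrors the `--print-diagnostics` branch of the training scripts, and the toy model and batch sizes are made up.

```python
import torch
from torch import nn

from icefall import diagnostics

model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 500))

opts = diagnostics.TensorDiagnosticOptions(2**22)  # allow 4 MB per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

for _ in range(5):  # a few forward/backward passes to accumulate statistics
    loss = model(torch.randn(8, 80)).sum()
    loss.backward()

diagnostic.print_diagnostics()  # per-module, per-dimension stats (abs/rms/eigs/...)
```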
diff --git a/icefall/dist.py b/icefall/dist.py
index 7016beafb..922f31a2f 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -21,17 +21,19 @@ import torch
 from torch import distributed as dist
 
 
-def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
+def setup_dist(
+    rank, world_size, master_port=None, use_ddp_launch=False, master_addr=None
+):
     """
     rank and world_size are used only if use_ddp_launch is False.
     """
     if "MASTER_ADDR" not in os.environ:
-        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_ADDR"] = (
+            "localhost" if master_addr is None else str(master_addr)
+        )
 
     if "MASTER_PORT" not in os.environ:
-        os.environ["MASTER_PORT"] = (
-            "12354" if master_port is None else str(master_port)
-        )
+        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
 
     if use_ddp_launch is False:
         dist.init_process_group("nccl", rank=rank, world_size=world_size)
diff --git a/icefall/env.py b/icefall/env.py
index 8aeda6be2..373e9a9ff 100644
--- a/icefall/env.py
+++ b/icefall/env.py
@@ -53,9 +53,7 @@ def get_git_sha1():
             )
             > 0
         )
-        git_commit = (
-            git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
-        )
+        git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
     except:  # noqa
         return None
 
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 570ed7d7a..d26ddbbd1 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -29,6 +29,7 @@ class CtcTrainingGraphCompiler(object):
         lexicon: Lexicon,
         device: torch.device,
         oov: str = "",
+        need_repeat_flag: bool = False,
     ):
         """
         Args:
@@ -39,6 +40,13 @@ class CtcTrainingGraphCompiler(object):
           oov:
             Out of vocabulary word. When a word in the transcript
             does not exist in the lexicon, it is replaced with `oov`.
+          need_repeat_flag:
+            If True, add an attribute named `_is_repeat_token_` to ctc_topo
+            indicating whether each token is a repeat token in the CTC graph.
+            This attribute is needed to implement delay-penalty for phone-based
+            CTC loss. See https://github.com/k2-fsa/k2/pull/1086 for more
+            details. Note: that k2 change MUST be present in your k2
+            installation to enable this flag.
         """
         L_inv = lexicon.L_inv.to(device)
         assert L_inv.requires_grad is False
@@ -53,6 +61,12 @@ class CtcTrainingGraphCompiler(object):
         ctc_topo = k2.ctc_topo(max_token_id, modified=False)
 
         self.ctc_topo = ctc_topo.to(device)
+
+        if need_repeat_flag:
+            self.ctc_topo._is_repeat_token_ = (
+                self.ctc_topo.labels != self.ctc_topo.aux_labels
+            )
+
         self.device = device
 
     def compile(self, texts: List[str]) -> k2.Fsa:
@@ -75,9 +89,7 @@ class CtcTrainingGraphCompiler(object):
 
         # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
         # is False, so we add epsilon self-loops here
-        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
-            transcript_fsa
-        )
+        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa)
 
         fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
 
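
The `need_repeat_flag` option documented above marks the arcs of the CTC topology that consume a repeated token without emitting it. A small sketch of what the flag computes (it assumes a k2 build that allows attaching the extra attribute, per the linked PR):

```python
import k2

max_token_id = 3
ctc_topo = k2.ctc_topo(max_token_id, modified=False)

# On a repeat arc the input label is a token while the output (aux) label is
# epsilon (0), so comparing the two label tensors marks exactly those arcs --
# the same expression assigned to `_is_repeat_token_` above.
is_repeat = ctc_topo.labels != ctc_topo.aux_labels
print(is_repeat)
```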
diff --git a/icefall/hooks.py b/icefall/hooks.py
index fbcf5e148..398a5f689 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -14,10 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import random
+
 import torch
 from torch import Tensor, nn
-import logging
 
 
 def register_inf_check_hooks(model: nn.Module) -> None:
@@ -56,7 +57,7 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if isinstance(_output, Tensor):
                 if not torch.isfinite(_output.to(torch.float32).sum()):
                     logging.warning(
-                        f"The sum of {_name}.grad is not finite" # ": {_output}"
+                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
@@ -65,28 +66,20 @@ def register_inf_check_hooks(model: nn.Module) -> None:
                     if not isinstance(o, Tensor):
                         continue
                     if not torch.isfinite(o.to(torch.float32).sum()):
-                        logging.warning(
-                            f"The sum of {_name}.grad[{i}] is not finite"
-                        )
+                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
 
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
 
-
     for name, parameter in model.named_parameters():
 
-        def param_backward_hook(
-                grad, _name=name
-        ):
+        def param_backward_hook(grad, _name=name):
             if not torch.isfinite(grad.to(torch.float32).sum()):
-                logging.warning(
-                    f"The sum of {_name}.param_grad is not finite"
-                )
+                logging.warning(f"The sum of {_name}.param_grad is not finite")
 
         parameter.register_hook(param_backward_hook)
 
 
-
 def _test_inf_check_hooks():
     model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
 
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 80bd7c1ee..22e1b78bb 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -49,18 +49,12 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue
 
             if len(a) < 2:
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
-                logging.info(
-                    "Every line is expected to contain at least 2 fields"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
+                logging.info("Every line is expected to contain at least 2 fields")
                 sys.exit(1)
             word = a[0]
             if word == "":
-                logging.info(
-                    f"Found bad line {line} in lexicon file {filename}"
-                )
+                logging.info(f"Found bad line {line} in lexicon file {filename}")
                 logging.info(" should not be a valid word")
                 sys.exit(1)
 
@@ -119,9 +113,7 @@ def convert_lexicon_to_ragged(
     lexicon_tmp = read_lexicon(filename)
     lexicon = dict(lexicon_tmp)
     if len(lexicon_tmp) != len(lexicon):
-        raise RuntimeError(
-            "It's assumed that each word has a unique pronunciation"
-        )
+        raise RuntimeError("It's assumed that each word has a unique pronunciation")
 
     for i in range(disambig_id):
         w = word_table[i]
diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py
new file mode 100644
index 000000000..5e2783a47
--- /dev/null
+++ b/icefall/lm_wrapper.py
@@ -0,0 +1,254 @@
+# Copyright (c)  2022  Xiaomi Corporation (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+
+import torch
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.rnn_lm.model import RnnLmModel
+from icefall.transformer_lm.model import TransformerLM
+from icefall.utils import AttributeDict, str2bool
+
+
+class LmScorer(torch.nn.Module):
+    """This is a wrapper for NN LMs
+    The language models supported include:
+        RNN,
+        Transformer
+    """
+
+    def __init__(
+        self,
+        lm_type: str,
+        params: AttributeDict,
+        device,
+        lm_scale: float = 0.3,
+    ):
+        super(LmScorer, self).__init__()
+        assert lm_type in ["rnn", "transformer"], f"{lm_type} is not supported"
+        self.lm_type = lm_type
+        self.lm = self.get_lm(lm_type, device, params)
+        self.lm_scale = lm_scale
+        self.params = params
+
+    @classmethod
+    def add_arguments(cls, parser):
+        # LM general arguments
+        parser.add_argument(
+            "--lm-vocab-size",
+            type=int,
+            default=500,
+        )
+
+        parser.add_argument(
+            "--lm-epoch",
+            type=int,
+            default=7,
+            help="""Which epoch to be used
+            """,
+        )
+
+        parser.add_argument(
+            "--lm-avg",
+            type=int,
+            default=1,
+            help="""Number of checkpoints to be averaged
+            """,
+        )
+
+        parser.add_argument("--lm-exp-dir", type=str, help="Path to LM experiments")
+
+        # Now RNNLM related arguments
+        parser.add_argument(
+            "--rnn-lm-embedding-dim",
+            type=int,
+            default=2048,
+            help="Embedding dim of the model",
+        )
+
+        parser.add_argument(
+            "--rnn-lm-hidden-dim",
+            type=int,
+            default=2048,
+            help="Hidden dim of the model",
+        )
+
+        parser.add_argument(
+            "--rnn-lm-num-layers",
+            type=int,
+            default=3,
+            help="Number of RNN layers the model",
+        )
+
+        parser.add_argument(
+            "--rnn-lm-tie-weights",
+            type=str2bool,
+            default=True,
+            help="""True to share the weights between the input embedding layer and the
+            last output linear layer
+            """,
+        )
+
+        # Now transformers
+        parser.add_argument(
+            "--transformer-lm-exp-dir", type=str, help="Directory of transformer LM exp"
+        )
+
+        parser.add_argument(
+            "--transformer-lm-dim-feedforward",
+            type=int,
+            default=2048,
+            help="Dimension of FFW module in transformer",
+        )
+
+        parser.add_argument(
+            "--transformer-lm-encoder-dim",
+            type=int,
+            default=768,
+            help="Encoder dimension of transformer",
+        )
+
+        parser.add_argument(
+            "--transformer-lm-embedding-dim",
+            type=int,
+            default=768,
+            help="Input embedding dimension of transformer",
+        )
+
+        parser.add_argument(
+            "--transformer-lm-nhead",
+            type=int,
+            default=8,
+            help="Number of attention heads in transformer",
+        )
+
+        parser.add_argument(
+            "--transformer-lm-num-layers",
+            type=int,
+            default=16,
+            help="Number of encoder layers in transformer",
+        )
+
+        parser.add_argument(
+            "--transformer-lm-tie-weights",
+            type=str2bool,
+            default=True,
+            help="If tie weights in transformer LM",
+        )
+
+    def get_lm(self, lm_type: str, device, params: AttributeDict) -> torch.nn.Module:
+        """Return the neural network LM
+
+        Args:
+            lm_type (str): Type name of NN LM
+        """
+        if lm_type == "rnn":
+            model = RnnLmModel(
+                vocab_size=params.vocab_size,
+                embedding_dim=params.rnn_lm_embedding_dim,
+                hidden_dim=params.rnn_lm_hidden_dim,
+                num_layers=params.rnn_lm_num_layers,
+                tie_weights=params.rnn_lm_tie_weights,
+            )
+
+            if params.lm_avg == 1:
+                load_checkpoint(
+                    f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model
+                )
+                model.to(device)
+            else:
+                start = params.lm_epoch - params.lm_avg + 1
+                filenames = []
+                for i in range(start, params.lm_epoch + 1):
+                    if start >= 0:
+                        filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt")
+                logging.info(f"averaging {filenames}")
+                model.to(device)
+                model.load_state_dict(average_checkpoints(filenames, device=device))
+
+        elif lm_type == "transformer":
+            model = TransformerLM(
+                vocab_size=params.vocab_size,
+                d_model=params.transformer_lm_encoder_dim,
+                embedding_dim=params.transformer_lm_embedding_dim,
+                dim_feedforward=params.transformer_lm_dim_feedforward,
+                nhead=params.transformer_lm_nhead,
+                num_layers=params.transformer_lm_num_layers,
+                tie_weights=params.transformer_lm_tie_weights,
+                params=params,
+            )
+
+            if params.lm_avg == 1:
+                load_checkpoint(
+                    f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model
+                )
+                model.to(device)
+            else:
+                start = params.lm_epoch - params.lm_avg + 1
+                filenames = []
+                for i in range(start, params.lm_epoch + 1):
+                    if start >= 0:
+                        filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt")
+                logging.info(f"averaging {filenames}")
+                model.to(device)
+                model.load_state_dict(average_checkpoints(filenames, device=device))
+        else:
+            raise NotImplementedError()
+
+        return model
+
+    def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
+        """Score the input and return the prediction
+        This requires the lm to have the method `score_token`
+        Args:
+            x (torch.Tensor): Input tokens
+            x_lens (torch.Tensor): Length of the input tokens
+            state (optional): LM states
+
+        """
+        return self.lm.score_token(x, x_lens, state)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    LmScorer.add_arguments(parser)
+    args = parser.parse_args()
+
+    params = AttributeDict()
+    params.update(vars(args))
+    params.vocab_size = params.lm_vocab_size  # get_lm() reads params.vocab_size
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    Scorer = LmScorer(lm_type="rnn", params=params, device=device)
+    Scorer.eval()
+
+    x = (
+        torch.tensor([[1, 4, 19, 256, 77], [1, 4, 19, 256, 77]])
+        .to(device)
+        .to(torch.int64)
+    )
+    x_lens = torch.tensor([5, 5]).to(device)
+
+    state = None
+
+    score, state = Scorer.score_token(x, x_lens, state)
+    print(score.shape)
+    print(score[0])
+    print(score[1])
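
A hedged sketch of how a decoding script might wire up `LmScorer` via the arguments registered above. The `--lm-exp-dir` value and the `params.vocab_size` assignment are assumptions for illustration; the recipes set `vocab_size` elsewhere.

```python
import argparse

import torch

from icefall import LmScorer
from icefall.utils import AttributeDict

parser = argparse.ArgumentParser()
LmScorer.add_arguments(parser)  # adds --lm-exp-dir, --lm-epoch, --rnn-lm-*, ...
args = parser.parse_args(["--lm-exp-dir", "rnn_lm/exp", "--lm-epoch", "7"])

params = AttributeDict()
params.update(vars(args))
params.vocab_size = params.lm_vocab_size  # get_lm() reads params.vocab_size

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Loads rnn_lm/exp/epoch-7.pt; lm_scale controls the shallow-fusion weight.
lm = LmScorer(lm_type="rnn", params=params, device=device, lm_scale=0.3)
lm.eval()
```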
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 2c479fc2c..b7777b434 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -63,10 +63,7 @@ def _compute_mmi_loss_exact_optimized(
 
     # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
     num_den_graphs_indexes = (
-        torch.stack([num_graphs_indexes, den_graphs_indexes])
-        .t()
-        .reshape(-1)
-        .to(device)
+        torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)
     )
 
     num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
@@ -116,19 +113,15 @@ def _compute_mmi_loss_exact_non_optimized(
 
     # TODO: pass output_beam as function argument
     num_lats = k2.intersect_dense(
-        num_graphs, dense_fsa_vec, output_beam=beam_size
+        num_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
     )
     den_lats = k2.intersect_dense(
-        den_graphs, dense_fsa_vec, output_beam=beam_size
+        den_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
     )
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
@@ -155,7 +148,7 @@ def _compute_mmi_loss_pruned(
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
 
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
+    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
 
     # the values for search_beam/output_beam/min_active_states/max_active_states
     # are not tuned. You may want to tune them.
@@ -168,13 +161,9 @@ def _compute_mmi_loss_pruned(
         max_active_states=10000,
     )
 
-    num_tot_scores = num_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
-    den_tot_scores = den_lats.get_tot_scores(
-        log_semiring=True, use_double_scores=True
-    )
+    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
 
     tot_scores = num_tot_scores - den_scale * den_tot_scores
 
diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py
index 0d901227d..600f09f2b 100644
--- a/icefall/mmi_graph_compiler.py
+++ b/icefall/mmi_graph_compiler.py
@@ -74,7 +74,9 @@ class MmiTrainingGraphCompiler(object):
         # CAUTION: The following line is crucial.
         # Arcs entering the back-off state have label equal to #0.
         # We have to change it to 0 here.
-        P.labels[P.labels >= first_token_disambig_id] = 0
+        labels = P.labels.clone()
+        labels[labels >= first_token_disambig_id] = 0
+        P.labels = labels
 
         P = k2.remove_epsilon(P)
         P = k2.arc_sort(P)
@@ -137,9 +139,7 @@ class MmiTrainingGraphCompiler(object):
             transcript_fsa
         )
 
-        transcript_fsa_with_self_loops = k2.arc_sort(
-            transcript_fsa_with_self_loops
-        )
+        transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops)
 
         num = k2.compose(
             self.ctc_topo_P,
@@ -155,9 +155,7 @@ class MmiTrainingGraphCompiler(object):
 
         ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
         if replicate_den:
-            indexes = torch.zeros(
-                len(texts), dtype=torch.int32, device=self.device
-            )
+            indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device)
             den = k2.index_fsa(ctc_topo_P_vec, indexes)
         else:
             den = ctc_topo_P_vec
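Tying the two MMI-related changes above together: per utterance, icefall/mmi.py combines the numerator lattice score with a scaled denominator lattice score. A self-contained toy illustration of that final step (the scores are made up):

    import torch

    num_tot_scores = torch.tensor([-120.0, -95.5])  # total log-scores of numerator lattices
    den_tot_scores = torch.tensor([-118.0, -94.0])  # total log-scores of denominator lattices
    den_scale = 1.0

    tot_scores = num_tot_scores - den_scale * den_tot_scores  # per-utterance MMI objective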
diff --git a/icefall/rnn_lm/.gitignore b/icefall/rnn_lm/.gitignore
new file mode 100644
index 000000000..877fb1e18
--- /dev/null
+++ b/icefall/rnn_lm/.gitignore
@@ -0,0 +1 @@
+icefall-librispeech-rnn-lm
diff --git a/icefall/rnn_lm/__init__.py b/icefall/rnn_lm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py
new file mode 100755
index 000000000..8850c1c71
--- /dev/null
+++ b/icefall/rnn_lm/check-onnx-streaming.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation
+
+"""
+Usage:
+
+./check-onnx-streaming.py \
+  --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
+  --onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx
+
+Note: You can download pre-trained models from
+https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+
+"""
+
+import argparse
+import logging
+from typing import Tuple
+
+import onnxruntime as ort
+import torch
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--jit",
+        required=True,
+        type=str,
+        help="Path to the torchscript model",
+    )
+
+    parser.add_argument(
+        "--onnx",
+        required=True,
+        type=str,
+        help="Path to the onnx model",
+    )
+
+    return parser
+
+
+class OnnxModel:
+    def __init__(self, filename: str):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=session_opts,
+        )
+
+        meta_data = self.model.get_modelmeta().custom_metadata_map
+        self.sos_id = int(meta_data["sos_id"])
+        self.eos_id = int(meta_data["eos_id"])
+        self.vocab_size = int(meta_data["vocab_size"])
+        self.num_layers = int(meta_data["num_layers"])
+        self.hidden_size = int(meta_data["hidden_size"])
+        print(meta_data)
+
+    def __call__(
+        self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        out = self.model.run(
+            [
+                self.model.get_outputs()[0].name,
+                self.model.get_outputs()[1].name,
+                self.model.get_outputs()[2].name,
+            ],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+                self.model.get_inputs()[1].name: y.numpy(),
+                self.model.get_inputs()[2].name: h0.numpy(),
+                self.model.get_inputs()[3].name: c0.numpy(),
+            },
+        )
+        return (
+            torch.from_numpy(out[0]),
+            torch.from_numpy(out[1]),
+            torch.from_numpy(out[2]),
+        )
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    logging.info(vars(args))
+
+    torch_model = torch.jit.load(args.jit).cpu()
+    onnx_model = OnnxModel(args.onnx)
+    N = torch.arange(1, 5).tolist()
+
+    num_layers = onnx_model.num_layers
+    hidden_size = onnx_model.hidden_size
+
+    for n in N:
+        L = torch.randint(low=1, high=100, size=(1,)).item()
+        x = torch.randint(
+            low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
+        )
+        y = torch.randint(
+            low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
+        )
+        h0 = torch.rand(num_layers, n, hidden_size)
+        c0 = torch.rand(num_layers, n, hidden_size)
+
+        torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0)
+        onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0)
+
+        for torch_v, onnx_v in zip(
+            (torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0)
+        ):
+            assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
+                torch_v.shape,
+                onnx_v.shape,
+                (torch_v - onnx_v).abs().max(),
+            )
+            print(n, L, torch_v.sum(), onnx_v.sum())
+
+
+if __name__ == "__main__":
+    torch.manual_seed(20230423)
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
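Beyond the batched check above, the with-state model can also be driven one token at a time, carrying (h, c) between calls. A hedged sketch reusing the OnnxModel wrapper defined in this script (the file path and token ids are illustrative assumptions):

    import torch

    model = OnnxModel("./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx")
    h = torch.zeros(model.num_layers, 1, model.hidden_size)
    c = torch.zeros(model.num_layers, 1, model.hidden_size)

    prev = model.sos_id
    total_nll = 0.0
    for tok in [42, 7, 99]:  # token ids are made up
        x = torch.tensor([[prev]], dtype=torch.int64)
        y = torch.tensor([[tok]], dtype=torch.int64)
        nll, h, c = model(x, y, h, c)  # nll has shape (1,)
        total_nll += nll.item()
        prev = tok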
diff --git a/icefall/rnn_lm/check-onnx.py b/icefall/rnn_lm/check-onnx.py
new file mode 100755
index 000000000..24c5395f8
--- /dev/null
+++ b/icefall/rnn_lm/check-onnx.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation
+
+"""
+Usage:
+
+./check-onnx.py \
+  --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
+  --onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx
+
+Note: You can download pre-trained models from
+https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+
+"""
+
+import argparse
+import logging
+
+import onnxruntime as ort
+import torch
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--jit",
+        required=True,
+        type=str,
+        help="Path to the torchscript model",
+    )
+
+    parser.add_argument(
+        "--onnx",
+        required=True,
+        type=str,
+        help="Path to the onnx model",
+    )
+
+    return parser
+
+
+class OnnxModel:
+    def __init__(self, filename: str):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=session_opts,
+        )
+
+        meta_data = self.model.get_modelmeta().custom_metadata_map
+        self.sos_id = int(meta_data["sos_id"])
+        self.eos_id = int(meta_data["eos_id"])
+        self.vocab_size = int(meta_data["vocab_size"])
+        print(meta_data)
+
+    def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
+        out = self.model.run(
+            [
+                self.model.get_outputs()[0].name,
+            ],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+                self.model.get_inputs()[1].name: x_lens.numpy(),
+            },
+        )
+        return torch.from_numpy(out[0])
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    logging.info(vars(args))
+
+    torch_model = torch.jit.load(args.jit).cpu()
+    onnx_model = OnnxModel(args.onnx)
+    N = torch.arange(1, 5).tolist()
+
+    for n in N:
+        L = torch.randint(low=1, high=100, size=(1,)).item()
+        x = torch.randint(
+            low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
+        )
+        x_lens = torch.full((n,), fill_value=L, dtype=torch.int64)
+        if n > 1:
+            x_lens[0] = L // 2 + 1
+
+        sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1)
+        sos_x = torch.cat([sos, x], dim=1)
+
+        pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1)
+        x_eos = torch.cat([x, pad_col], dim=1)
+
+        row_index = torch.arange(0, n, dtype=x.dtype)
+        x_eos[row_index, x_lens] = onnx_model.eos_id
+
+        torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1)
+        onnx_nll = onnx_model(x, x_lens)
+        # Note: For int8 models, the differences may be quite large,
+        # e.g., within 0.9
+        assert torch.allclose(torch_nll, onnx_nll), (
+            torch_nll,
+            onnx_nll,
+        )
+        print(n, L, torch_nll, onnx_nll)
+
+
+if __name__ == "__main__":
+    torch.manual_seed(20230420)
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
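The sos_x / x_eos construction in main() above is easy to misread, so here is a small, self-contained illustration with concrete numbers (sos_id = eos_id = 1 and blank padding 0, matching the defaults used elsewhere in this patch):

    import torch

    x = torch.tensor([[5, 6, 7], [8, 9, 0]])  # second row has one padding token
    x_lens = torch.tensor([3, 2])
    sos_id = eos_id = 1

    n = x.size(0)
    sos = torch.full((1,), fill_value=sos_id, dtype=x.dtype).expand(n, 1)
    sos_x = torch.cat([sos, x], dim=1)       # [[1, 5, 6, 7], [1, 8, 9, 0]]

    pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1)
    x_eos = torch.cat([x, pad_col], dim=1)
    x_eos[torch.arange(n), x_lens] = eos_id  # [[5, 6, 7, 1], [8, 9, 1, 0]]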
diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py
index 550801a8f..cc566bd92 100755
--- a/icefall/rnn_lm/compute_perplexity.py
+++ b/icefall/rnn_lm/compute_perplexity.py
@@ -20,7 +20,7 @@ Usage:
   ./rnn_lm/compute_perplexity.py \
     --epoch 4 \
     --avg 2 \
-    --lm-data ./data/bpe_500/sorted_lm_data-test.pt
+    --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt
 
 """
 
@@ -33,7 +33,7 @@ import torch
 from dataset import get_dataloader
 from model import RnnLmModel
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.utils import AttributeDict, setup_logger, str2bool
 
 
@@ -49,6 +49,7 @@ def get_parser():
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
+
     parser.add_argument(
         "--avg",
         type=int,
@@ -58,6 +59,16 @@ def get_parser():
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -154,7 +165,14 @@ def main():
 
     params = AttributeDict(vars(args))
 
-    setup_logger(f"{params.exp_dir}/log-ppl/")
+    if params.iter > 0:
+        setup_logger(
+            f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}"
+        )
+    else:
+        setup_logger(
+            f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}"
+        )
     logging.info("Computing perplexity started")
     logging.info(params)
 
@@ -173,19 +191,39 @@ def main():
         tie_weights=params.tie_weights,
     )
 
-    if params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        logging.info(f"averaging {filenames}")
         model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+    elif params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
         filenames = []
         for i in range(start, params.epoch + 1):
-            if start >= 0:
+            if i >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
         model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
 
+    model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
     num_param_requires_grad = sum(
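For reference, the perplexity reported by this script is the exponential of the average per-token negative log-likelihood. A minimal, self-contained illustration (the numbers are made up):

    import math

    total_nll = 123456.7  # sum of token-level negative log-likelihoods over the test data
    total_tokens = 40000  # number of scored tokens
    ppl = math.exp(total_nll / total_tokens)
    print(f"ppl: {ppl:.3f}")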
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
index 598e329c4..53be53f64 100644
--- a/icefall/rnn_lm/dataset.py
+++ b/icefall/rnn_lm/dataset.py
@@ -1,4 +1,4 @@
-# Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
+# Copyright (c)  2021  Xiaomi Corporation (authors: Daniel Povey, Fangjun Kuang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -155,12 +155,8 @@ class LmDatasetCollate:
         sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
         sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
 
-        x = sentence_tokens_with_sos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
-        y = sentence_tokens_with_eos.pad(
-            mode="constant", padding_value=self.blank_id
-        )
+        x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
+        y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
         sentence_token_lengths += 1  # plus 1 since we added a SOS
 
         return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
@@ -198,7 +194,7 @@ def get_dataloader(
         batch_size=params.batch_size,
     )
     if is_distributed:
-        sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
+        sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
     else:
         sampler = None
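To make the collate behaviour above concrete, this is the kind of output LmDatasetCollate produces for the two sentences [5, 6, 7] and [8, 9] with sos_id = eos_id = 1 and blank_id = 0, shown here with plain tensors rather than k2 ragged tensors (values are illustrative):

    import torch

    x = torch.tensor([[1, 5, 6, 7],
                      [1, 8, 9, 0]])  # sos-prefixed, padded with blank_id
    y = torch.tensor([[5, 6, 7, 1],
                      [8, 9, 1, 0]])  # eos-appended, padded with blank_id
    sentence_token_lengths = torch.tensor([4, 3])  # original lengths + 1 for the sos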
 
diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py
new file mode 100755
index 000000000..1d9af5e3d
--- /dev/null
+++ b/icefall/rnn_lm/export-onnx.py
@@ -0,0 +1,397 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict
+
+import onnx
+import torch
+from model import RnnLmModel
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from train import get_params
+
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.utils import AttributeDict, str2bool
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = value
+
+    onnx.save(model, filename)
+
+
+# A wrapper for the RnnLm model to simplify the C++ calling code
+# when exporting the model to ONNX.
+#
+# TODO(fangjun): The current wrapper works only for non-streaming ASR
+# since we don't expose the LM state and it is used to score
+# a complete sentence at once.
+class RnnLmModelWrapper(torch.nn.Module):
+    def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int):
+        super().__init__()
+        self.model = model
+        self.sos_id = sos_id
+        self.eos_id = eos_id
+
+    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          x:
+            A 2-D tensor of shape (N, L) with dtype torch.int64.
+            It does not contain SOS or EOS. We will add SOS and EOS inside
+            this function.
+          x_lens:
+            A 1-D tensor of shape (N,) with dtype torch.int64. It contains
+            number of valid tokens in ``x`` before padding.
+        Returns:
+          Return a 1-D tensor of shape (N,) containing the negative log-likelihood.
+          Its dtype is torch.float32.
+        """
+        N = x.size(0)
+
+        sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand(
+            N, 1
+        )
+        sos_x = torch.cat([sos_tensor, x], dim=1)
+
+        pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1)
+        x_eos = torch.cat([x, pad_col], dim=1)
+
+        row_index = torch.arange(0, N, dtype=x.dtype)
+        x_eos[row_index, x_lens] = self.eos_id
+
+        # use x_lens + 1 here since we prepended x with sos
+        return (
+            self.model(x=sos_x, y=x_eos, lengths=x_lens + 1)
+            .to(torch.float32)
+            .sum(dim=1)
+        )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=29,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=5,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        default=500,
+        help="Vocabulary size of the model",
+    )
+
+    parser.add_argument(
+        "--embedding-dim",
+        type=int,
+        default=2048,
+        help="Embedding dim of the model",
+    )
+
+    parser.add_argument(
+        "--hidden-dim",
+        type=int,
+        default=2048,
+        help="Hidden dim of the model",
+    )
+
+    parser.add_argument(
+        "--num-layers",
+        type=int,
+        default=3,
+        help="Number of RNN layers the model",
+    )
+
+    parser.add_argument(
+        "--tie-weights",
+        type=str2bool,
+        default=True,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    return parser
+
+
+def export_without_state(
+    model: RnnLmModel,
+    filename: str,
+    params: AttributeDict,
+    opset_version: int,
+):
+    model_wrapper = RnnLmModelWrapper(
+        model,
+        sos_id=params.sos_id,
+        eos_id=params.eos_id,
+    )
+
+    N = 1
+    L = 20
+    x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
+    x_lens = torch.full((N,), fill_value=L, dtype=torch.int64)
+
+    # Note(fangjun): The following warnings can be ignored.
+    # We can use ./check-onnx.py to validate the exported model with batch_size > 1
+    """
+    torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
+    with a batch_size other than 1, with a variable length with LSTM can cause
+    an error when running the ONNX model with a different batch size. Make sure
+    to save the model with a batch size of 1, or define the initial states
+    (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
+    with a batch_size other than 1, " +
+    """
+
+    torch.onnx.export(
+        model_wrapper,
+        (x, x_lens),
+        filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["x", "x_lens"],
+        output_names=["nll"],
+        dynamic_axes={
+            "x": {0: "N", 1: "L"},
+            "x_lens": {0: "N"},
+            "nll": {0: "N"},
+        },
+    )
+
+    meta_data = {
+        "model_type": "rnnlm",
+        "version": "1",
+        "model_author": "k2-fsa",
+        "comment": "rnnlm without state",
+        "sos_id": str(params.sos_id),
+        "eos_id": str(params.eos_id),
+        "vocab_size": str(params.vocab_size),
+        "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
+    }
+    logging.info(f"meta_data: {meta_data}")
+
+    add_meta_data(filename=filename, meta_data=meta_data)
+
+
+def export_with_state(
+    model: RnnLmModel,
+    filename: str,
+    params: AttributeDict,
+    opset_version: int,
+):
+    N = 1
+    L = 20
+    num_layers = model.rnn.num_layers
+    hidden_size = model.rnn.hidden_size
+    embedding_dim = model.embedding_dim
+
+    x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
+    y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
+    h0 = torch.zeros(num_layers, N, hidden_size)
+    c0 = torch.zeros(num_layers, N, hidden_size)
+
+    # Note(fangjun): The following warnings can be ignored.
+    # We can use ./check-onnx.py to validate the exported model with batch_size > 1
+    """
+    torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
+    with a batch_size other than 1, with a variable length with LSTM can cause
+    an error when running the ONNX model with a different batch size. Make sure
+    to save the model with a batch size of 1, or define the initial states
+    (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
+    with a batch_size other than 1, " +
+    """
+
+    torch.onnx.export(
+        model,
+        (x, y, h0, c0),
+        filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["x", "y", "h0", "c0"],
+        output_names=["nll", "next_h0", "next_c0"],
+        dynamic_axes={
+            "x": {0: "N", 1: "L"},
+            "y": {0: "N", 1: "L"},
+            "h0": {1: "N"},
+            "c0": {1: "N"},
+            "nll": {0: "N"},
+            "next_h0": {1: "N"},
+            "next_c0": {1: "N"},
+        },
+    )
+
+    meta_data = {
+        "model_type": "rnnlm",
+        "version": "1",
+        "model_author": "k2-fsa",
+        "comment": "rnnlm state",
+        "sos_id": str(params.sos_id),
+        "eos_id": str(params.eos_id),
+        "vocab_size": str(params.vocab_size),
+        "num_layers": str(num_layers),
+        "hidden_size": str(hidden_size),
+        "embedding_dim": str(embedding_dim),
+        "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
+    }
+    logging.info(f"meta_data: {meta_data}")
+
+    add_meta_data(filename=filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    logging.info(params)
+
+    device = torch.device("cpu")
+    logging.info(f"device: {device}")
+
+    model = RnnLmModel(
+        vocab_size=params.vocab_size,
+        embedding_dim=params.embedding_dim,
+        hidden_dim=params.hidden_dim,
+        num_layers=params.num_layers,
+        tie_weights=params.tie_weights,
+    )
+
+    model.to(device)
+
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+    elif params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.iter > 0:
+        suffix = f"iter-{params.iter}"
+    else:
+        suffix = f"epoch-{params.epoch}"
+
+    suffix += f"-avg-{params.avg}"
+
+    opset_version = 13
+
+    logging.info("Exporting model without state")
+    filename = params.exp_dir / f"no-state-{suffix}.onnx"
+    export_without_state(
+        model=model,
+        filename=filename,
+        params=params,
+        opset_version=opset_version,
+    )
+
+    filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx"
+    quantize_dynamic(
+        model_input=filename,
+        model_output=filename_int8,
+        weight_type=QuantType.QInt8,
+    )
+
+    # now for streaming export
+    saved_forward = model.__class__.forward
+    model.__class__.forward = model.__class__.streaming_forward
+    streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
+    export_with_state(
+        model=model,
+        filename=streaming_filename,
+        params=params,
+        opset_version=opset_version,
+    )
+    model.__class__.forward = saved_forward
+
+    streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx"
+    quantize_dynamic(
+        model_input=streaming_filename,
+        model_output=streaming_filename_int8,
+        weight_type=QuantType.QInt8,
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
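A quick way to confirm that add_meta_data() above took effect is to inspect the exported file's metadata_props; a small sketch (the filename is an example):

    import onnx

    m = onnx.load("no-state-epoch-99-avg-1.onnx")
    for prop in m.metadata_props:
        print(prop.key, "=", prop.value)  # sos_id, eos_id, vocab_size, ...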
diff --git a/icefall/rnn_lm/export-onnx.sh b/icefall/rnn_lm/export-onnx.sh
new file mode 100755
index 000000000..6e3262b5e
--- /dev/null
+++ b/icefall/rnn_lm/export-onnx.sh
@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+
+# We use the model from
+# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
+# as an example
+
+export CUDA_VISIBLE_DEVICES=
+
+if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
+  GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+  pushd icefall-librispeech-rnn-lm/exp
+  git lfs pull --include "pretrained.pt"
+  ln -s pretrained.pt epoch-99.pt
+  popd
+fi
+
+python3 ./export-onnx.py \
+  --exp-dir ./icefall-librispeech-rnn-lm/exp \
+  --epoch 99 \
+  --avg 1 \
+  --vocab-size 500 \
+  --embedding-dim 2048 \
+  --hidden-dim 2048 \
+  --num-layers 3 \
+  --tie-weights 1
+
diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py
index 094035fce..be4e7f8c5 100644
--- a/icefall/rnn_lm/export.py
+++ b/icefall/rnn_lm/export.py
@@ -25,8 +25,8 @@ from pathlib import Path
 import torch
 from model import RnnLmModel
 
-from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict, load_averaged_model, str2bool
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
+from icefall.utils import AttributeDict, str2bool
 
 
 def get_parser():
@@ -51,6 +51,16 @@ def get_parser():
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--vocab-size",
         type=int,
@@ -108,6 +118,7 @@ def get_parser():
     return parser
 
 
+@torch.no_grad()
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -133,11 +144,36 @@ def main():
 
     model.to(device)
 
-    if params.avg == 1:
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
-        model = load_averaged_model(
-            params.exp_dir, model, params.epoch, params.avg, device
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
         )
 
     model.to("cpu")
@@ -145,6 +181,10 @@ def main():
 
     if params.jit:
         logging.info("Using torch.jit.script")
+
+        model.__class__.streaming_forward = torch.jit.export(
+            model.__class__.streaming_forward
+        )
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
         model.save(str(filename))
@@ -159,9 +199,7 @@ def main():
 
 
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     main()
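Since streaming_forward is now marked with torch.jit.export before scripting, it remains callable on the saved cpu_jit.pt. A hedged sketch (the path and shapes follow the export.sh defaults below: 3 layers, hidden dim 2048):

    import torch

    model = torch.jit.load("./icefall-librispeech-rnn-lm/exp/cpu_jit.pt").cpu().eval()
    x = torch.tensor([[5, 6, 7]], dtype=torch.int64)
    y = torch.tensor([[6, 7, 1]], dtype=torch.int64)  # shifted targets, eos_id = 1
    h0 = torch.zeros(3, 1, 2048)  # (num_layers, N, hidden_dim)
    c0 = torch.zeros(3, 1, 2048)
    nll, h1, c1 = model.streaming_forward(x, y, h0, c0)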
diff --git a/icefall/rnn_lm/export.sh b/icefall/rnn_lm/export.sh
new file mode 100755
index 000000000..678bc294e
--- /dev/null
+++ b/icefall/rnn_lm/export.sh
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+
+# We use the model from
+# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
+# as an example
+
+export CUDA_VISIBLE_DEVICES=
+
+if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
+  GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+  pushd icefall-librispeech-rnn-lm/exp
+  git lfs pull --include "pretrained.pt"
+  ln -s pretrained.pt epoch-99.pt
+  popd
+fi
+
+python3 ./export.py \
+  --exp-dir ./icefall-librispeech-rnn-lm/exp \
+  --epoch 99 \
+  --avg 1 \
+  --vocab-size 500 \
+  --embedding-dim 2048 \
+  --hidden-dim 2048 \
+  --num-layers 3 \
+  --tie-weights 1 \
+  --jit 1
+
diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py
index a6144727a..a8eaadc0c 100644
--- a/icefall/rnn_lm/model.py
+++ b/icefall/rnn_lm/model.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import Tuple
 
 import torch
 import torch.nn.functional as F
@@ -47,6 +48,11 @@ class RnnLmModel(torch.nn.Module):
             and https://arxiv.org/abs/1611.01462
         """
         super().__init__()
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.tie_weights = tie_weights
 
         self.input_embedding = torch.nn.Embedding(
             num_embeddings=vocab_size,
@@ -74,6 +80,46 @@ class RnnLmModel(torch.nn.Module):
 
         self.cache = {}
 
+    def streaming_forward(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        h0: torch.Tensor,
+        c0: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 2-D tensor of shape (N, L). We won't prepend it with SOS.
+          y:
+            A 2-D tensor of shape (N, L). We won't append it with EOS.
+          h0:
+            A 3-D tensor of shape (num_layers, N, hidden_size).
+            (If proj_size > 0, then it is (num_layers, N, proj_size))
+          c0:
+            A 3-D tensor of shape (num_layers, N, hidden_size).
+        Returns:
+          Return a tuple containing 3 tensors:
+            - negative loglike (nll), a 1-D tensor of shape (N,)
+            - next_h0, a 3-D tensor with the same shape as h0
+            - next_c0, a 3-D tensor with the same shape as c0
+        """
+        assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
+        assert x.shape == y.shape, (x.shape, y.shape)
+
+        # embedding is of shape (N, L, embedding_dim)
+        embedding = self.input_embedding(x)
+        # Note: We use batch_first==True
+        rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0))
+        logits = self.output_linear(rnn_out)
+        nll_loss = F.cross_entropy(
+            logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
+        )
+
+        batch_size = x.size(0)
+        nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1)
+        return nll_loss, next_h0, next_c0
+
     def forward(
         self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
     ) -> torch.Tensor:
@@ -129,9 +175,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
@@ -155,20 +199,36 @@ class RnnLmModel(torch.nn.Module):
     def clean_cache(self):
         self.cache = {}
 
-    def score_token(self, tokens: torch.Tensor, state=None):
+    def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
+        """Score a batch of tokens, i.e each sample in the batch should be a
+        single token. For example, x = torch.tensor([[5],[10],[20]])
+
+
+        Args:
+            x (torch.Tensor):
+                A batch of tokens
+            x_lens (torch.Tensor):
+                The length of tokens in the batch before padding
+            state (optional):
+                Either None or a tuple of two torch.Tensor. Each tensor has
+                the shape of (num_layers, bs, hidden_dim)
+
+        Returns:
+            _type_: _description_
+        """
         device = next(self.parameters()).device
-        batch_size = tokens.size(0)
+        batch_size = x.size(0)
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            ).to(device)
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
+                device
+            )
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
+                device
+            )
 
-        embedding = self.input_embedding(tokens)
+        embedding = self.input_embedding(x)
         rnn_out, states = self.rnn(embedding, (h, c))
         logits = self.output_linear(rnn_out)
 
@@ -181,12 +241,8 @@ class RnnLmModel(torch.nn.Module):
         if state:
             h, c = state
         else:
-            h = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
-            c = torch.zeros(
-                self.rnn.num_layers, batch_size, self.rnn.input_size
-            )
+            h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
+            c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
 
         device = next(self.parameters()).device
 
@@ -194,9 +250,7 @@ class RnnLmModel(torch.nn.Module):
         tokens_eos = add_eos(tokens, eos_id)
         sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
 
-        sentence_lengths = (
-            sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
-        )
+        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
         x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
         y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
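The per-sentence NLL in streaming_forward above is just token-level cross-entropy summed over the sentence; a self-contained toy check of that reduction:

    import torch
    import torch.nn.functional as F

    N, L, V = 2, 4, 10  # batch size, length, vocab size (toy values)
    logits = torch.randn(N, L, V)
    y = torch.randint(0, V, (N, L))

    nll = F.cross_entropy(logits.reshape(-1, V), y.reshape(-1), reduction="none")
    nll = nll.reshape(N, -1).sum(dim=1)  # shape (N,): one NLL per sentence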
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index bb5f03fb9..91df4f921 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -24,7 +24,7 @@ Usage:
         --use-fp16 0 \
         --embedding-dim 800 \
         --hidden-dim 200 \
-        --num-layers 2\
+        --num-layers 2 \
         --batch-size 400
 
 """
@@ -49,6 +49,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -83,7 +84,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=10,
+        default=30,
         help="Number of epochs to train.",
     )
 
@@ -110,14 +111,14 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to use half precision training.",
     )
 
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=50,
+        default=400,
     )
 
     parser.add_argument(
@@ -165,7 +166,7 @@ def get_parser():
     parser.add_argument(
         "--tie-weights",
         type=str2bool,
-        default=False,
+        default=True,
         help="""True to share the weights between the input embedding layer and the
         last output linear layer
         """,
@@ -178,6 +179,33 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )
 
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1e-3,
+    )
+
+    parser.add_argument(
+        "--max-sent-len",
+        type=int,
+        default=200,
+        help="""Maximum number of tokens in a sentence. This is used
+        to adjust batch-size dynamically""",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=2000,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
     return parser
 
 
@@ -190,16 +218,15 @@ def get_params() -> AttributeDict:
             "sos_id": 1,
             "eos_id": 1,
             "blank_id": 0,
-            "lr": 1e-3,
             "weight_decay": 1e-6,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 200,
+            "log_interval": 100,
             "reset_interval": 2000,
-            "valid_interval": 5000,
+            "valid_interval": 200,
             "env_info": get_env_info(),
         }
     )
@@ -382,6 +409,7 @@ def train_one_epoch(
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
+    rank: int = 0,
 ) -> None:
     """Train the model for one epoch.
 
@@ -430,6 +458,19 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
 
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                params=params,
+                optimizer=optimizer,
+                rank=rank,
+            )
+
         if batch_idx % params.log_interval == 0:
             # Note: "frames" here means "num_tokens"
             this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
@@ -446,17 +487,13 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
                 tb_writer.add_scalar(
                     "train/current_ppl", this_batch_ppl, params.batch_idx_train
                 )
 
-                tb_writer.add_scalar(
-                    "train/tot_ppl", tot_ppl, params.batch_idx_train
-                )
+                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -535,6 +572,9 @@ def run(rank, world_size, args):
         tie_weights=params.tie_weights,
     )
 
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
     model.to(device)
@@ -581,6 +621,7 @@ def run(rank, world_size, args):
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
+            rank=rank,
         )
 
         save_checkpoint(
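The save_every_n logic added above writes intermediate checkpoints on a fixed schedule, in addition to the per-epoch ones. A tiny, purely illustrative sketch of which global batch indices trigger a save:

    save_every_n = 2000
    for batch_idx_train in range(1, 10001):
        if batch_idx_train > 0 and batch_idx_train % save_every_n == 0:
            print(f"would write exp-dir/checkpoint-{batch_idx_train}.pt")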
diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py
new file mode 100755
index 000000000..29a2cd7f7
--- /dev/null
+++ b/icefall/shared/convert-k2-to-openfst.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# Copyright    2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes as input an FST in k2 format and convert it
+to an FST in OpenFST format.
+
+The generated FST is saved into a binary file and its type is
+StdVectorFst.
+
+Usage examples:
+(1) Convert an acceptor
+
+  ./convert-k2-to-openfst.py in.pt binary.fst
+
+(2) Convert a transducer
+
+  ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst.utils
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--olabels",
+        type=str,
+        default=None,
+        help="""If not empty, the input FST is assumed to be a transducer
+        and we use its attribute specified by "olabels" as the output labels.
+        """,
+    )
+    parser.add_argument(
+        "input_filename",
+        type=str,
+        help="Path to the input FST in k2 format",
+    )
+
+    parser.add_argument(
+        "output_filename",
+        type=str,
+        help="Path to the output FST in OpenFst format",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    logging.info(f"{vars(args)}")
+
+    input_filename = args.input_filename
+    output_filename = args.output_filename
+    olabels = args.olabels
+
+    if Path(output_filename).is_file():
+        logging.info(f"{output_filename} already exists - skipping")
+        return
+
+    assert Path(input_filename).is_file(), f"{input_filename} does not exist"
+    logging.info(f"Loading {input_filename}")
+    k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
+    if olabels:
+        assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
+
+    p = Path(output_filename).parent
+    if not p.is_dir():
+        logging.info(f"Creating {p}")
+        p.mkdir(parents=True)
+
+    logging.info("Converting (May take some time if the input FST is large)")
+    fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
+    logging.info(f"Saving to {output_filename}")
+    fst.write(output_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index c2edd823e..7150297d6 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -15,30 +15,43 @@
 # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
 # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
 
-import sys
-import os
-import re
+import argparse
 import io
 import math
-import argparse
+import os
+import re
+import sys
 from collections import Counter, defaultdict
 
-
-parser = argparse.ArgumentParser(description="""
+parser = argparse.ArgumentParser(
+    description="""
     Generate kneser-ney language model as arpa format. By default,
     it will read the corpus from standard input, and output to standard output.
-    """)
-parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
+    """
+)
+parser.add_argument(
+    "-ngram-order",
+    type=int,
+    default=4,
+    choices=[2, 3, 4, 5, 6, 7],
+    help="Order of n-gram",
+)
 parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
-parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
-parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
+parser.add_argument(
+    "-lm", type=str, default=None, help="Path to output arpa file for language models"
+)
+parser.add_argument(
+    "-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level"
+)
 args = parser.parse_args()
 
-default_encoding = "latin-1"  # For encoding-agnostic scripts, we assume byte stream as input.
-                              # Need to be very careful about the use of strip() and split()
-                              # in this case, because there is a latin-1 whitespace character
-                              # (nbsp) which is part of the unicode encoding range.
-                              # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+# For encoding-agnostic scripts, we assume byte stream as input.
+# Need to be very careful about the use of strip() and split()
+# in this case, because there is a latin-1 whitespace character
+# (nbsp) which is part of the unicode encoding range.
+# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
+default_encoding = "latin-1"
+
 strip_chars = " \t\r\n"
 whitespace = re.compile("[ \t]+")
 
@@ -52,7 +65,8 @@ class CountsForHistory:
         # The 'lambda: defaultdict(float)' is an anonymous function taking no
         # arguments that returns a new defaultdict(float).
         self.word_to_count = defaultdict(int)
-        self.word_to_context = defaultdict(set)  # using a set to count the number of unique contexts
+        # using a set to count the number of unique contexts
+        self.word_to_context = defaultdict(set)
         self.word_to_f = dict()  # discounted probability
         self.word_to_bow = dict()  # back-off weight
         self.total_count = 0
@@ -62,10 +76,15 @@ class CountsForHistory:
 
     def __str__(self):
         # e.g. returns ' total=12: 3->4, 4->6, -1->2'
-        return ' total={0}: {1}'.format(
+        return " total={0}: {1}".format(
             str(self.total_count),
-            ', '.join(['{0} -> {1}'.format(word, count)
-                      for word, count in self.word_to_count.items()]))
+            ", ".join(
+                [
+                    "{0} -> {1}".format(word, count)
+                    for word, count in self.word_to_count.items()
+                ]
+            ),
+        )
 
     def add_count(self, predicted_word, context_word, count):
         assert count >= 0
@@ -85,7 +104,7 @@ class NgramCounts:
     # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
     # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
     # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
-    def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
+    def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
         assert ngram_order >= 2
 
         self.ngram_order = ngram_order
@@ -103,39 +122,47 @@ class NgramCounts:
     # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
     # 1.
     def add_count(self, history, predicted_word, context_word, count):
-        self.counts[len(history)][history].add_count(predicted_word, context_word, count)
+        self.counts[len(history)][history].add_count(
+            predicted_word, context_word, count
+        )
 
     # 'line' is a string containing a sequence of integer word-ids.
     # This function adds the un-smoothed counts from this line of text.
     def add_raw_counts_from_line(self, line):
-        if line == '':
+        if line == "":
             words = [self.bos_symbol, self.eos_symbol]
         else:
             words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
 
         for i in range(len(words)):
-            for n in range(1, self.ngram_order+1):
+            for n in range(1, self.ngram_order + 1):
                 if i + n > len(words):
                     break
-                ngram = words[i: i + n]
+                ngram = words[i : i + n]
                 predicted_word = ngram[-1]
-                history = tuple(ngram[: -1])
+                history = tuple(ngram[:-1])
                 if i == 0 or n == self.ngram_order:
                     context_word = None
                 else:
-                    context_word = words[i-1]
+                    context_word = words[i - 1]
 
                 self.add_count(history, predicted_word, context_word, 1)
 
     def add_raw_counts_from_standard_input(self):
         lines_processed = 0
-        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)  # byte stream as input
+        # byte stream as input
+        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)
         for line in infile:
             line = line.strip(strip_chars)
             self.add_raw_counts_from_line(line)
             lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def add_raw_counts_from_file(self, filename):
         lines_processed = 0
@@ -145,7 +172,12 @@ class NgramCounts:
                 self.add_raw_counts_from_line(line)
                 lines_processed += 1
         if lines_processed == 0 or args.verbose > 0:
-            print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
+            print(
+                "make_phone_lm.py: processed {0} lines of input".format(
+                    lines_processed
+                ),
+                file=sys.stderr,
+            )
 
     def cal_discounting_constants(self):
         # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
@@ -153,9 +185,10 @@ class NgramCounts:
         # This constant is used similarly to absolute discounting.
         # Return value: d is a list of floats, where d[N+1] = D_N
 
-        self.d = [0]  # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
-                      # This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
-                      # but perhaps this is not the case for some other scenarios.
+        # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
+        # This is a special case: we currently assume that all vocabulary words have
+        # been seen in the dictionary, but perhaps this is not the case for some other scenarios.
+        self.d = [0]
         for n in range(1, self.ngram_order):
             this_order_counts = self.counts[n]
             n1 = 0
@@ -165,9 +198,11 @@ class NgramCounts:
                 n1 += stat[1]
                 n2 += stat[2]
             assert n1 + 2 * n2 > 0
-            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))   # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, 
-                                                                # which could happen if the number of symbols is small.
-                                                                # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+
+            # We are doing this max(0.1, n1) to avoid a zero discounting constant D due to n1=0,
+            # which could happen if the number of symbols is small.
+            # Otherwise, zero discounting constant can cause division by zero in computing BOW.
+            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))
 
     def cal_f(self):
         # f(a_z) is a probability distribution of word sequence a_z.
@@ -182,7 +217,9 @@ class NgramCounts:
         this_order_counts = self.counts[n]
         for hist, counts_for_hist in this_order_counts.items():
             for w, c in counts_for_hist.word_to_count.items():
-                counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                counts_for_hist.word_to_f[w] = (
+                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                )
 
         # lower order N-grams
         for n in range(0, self.ngram_order - 1):
@@ -196,11 +233,17 @@ class NgramCounts:
                 if n_star_star != 0:
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = len(counts_for_hist.word_to_context[w])
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
+                        )
                else:  # patterns begin with <s>, they do not have "modified count", so use raw count instead
                     for w in counts_for_hist.word_to_count.keys():
                         n_star_z = counts_for_hist.word_to_count[w]
-                        counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
+                        counts_for_hist.word_to_f[w] = (
+                            max((n_star_z - self.d[n]), 0)
+                            * 1.0
+                            / counts_for_hist.total_count
+                        )
 
     def cal_bow(self):
         # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
@@ -240,12 +283,15 @@ class NgramCounts:
                         sum_z1_f_z = 0
                         _ = a_[1:]
                         _counts_for_hist = self.counts[len(_)][_]
-                        for u in a_counts_for_hist.word_to_count.keys():  # Should be careful here: what is Z1
+                        # Should be careful here: what is Z1
+                        for u in a_counts_for_hist.word_to_count.keys():
                             sum_z1_f_z += _counts_for_hist.word_to_f[u]
 
                         if sum_z1_f_z < 1:
                             # assert sum_z1_f_a_z < 1
-                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
+                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
+                                1.0 - sum_z1_f_z
+                            )
                         else:
                             counts_for_hist.word_to_bow[w] = None
 
@@ -259,7 +305,9 @@ class NgramCounts:
                     ngram = " ".join(hist) + " " + w
                     ngram = ngram.strip(strip_chars)
 
-                    res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
+                    res.append(
+                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
+                    )
         res.sort(reverse=True)
         for r in res:
             print(r)
@@ -322,27 +370,40 @@ class NgramCounts:
                     if bow is None:
                         res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                     else:
-                        res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
+                        res.append(
+                            "{1}\t{0}\t{2}".format(
+                                ngram, math.log(f, 10), math.log(bow, 10)
+                            )
+                        )
         res.sort(reverse=True)
         for r in res:
             print(r)
 
-    def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
+    def print_as_arpa(
+        self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1")
+    ):
         # print as ARPA format.
 
-        print('\\data\\', file=fout)
+        print("\\data\\", file=fout)
         for hist_len in range(self.ngram_order):
             # print the number of n-grams.
-            print('ngram {0}={1}'.format(
-                hist_len + 1,
-                sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
-                file=fout
+            print(
+                "ngram {0}={1}".format(
+                    hist_len + 1,
+                    sum(
+                        [
+                            len(counts_for_hist.word_to_f)
+                            for counts_for_hist in self.counts[hist_len].values()
+                        ]
+                    ),
+                ),
+                file=fout,
             )
 
-        print('', file=fout)
+        print("", file=fout)
 
         for hist_len in range(self.ngram_order):
-            print('\\{0}-grams:'.format(hist_len + 1), file=fout)
+            print("\\{0}-grams:".format(hist_len + 1), file=fout)
 
             this_order_counts = self.counts[hist_len]
             for hist, counts_for_hist in this_order_counts.items():
@@ -354,12 +415,12 @@ class NgramCounts:
                     if prob == 0:  # f() is always 0
                         prob = 1e-99
 
-                    line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
+                    line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram))
                     if bow is not None:
-                        line += '\t{0}'.format('%.7f' % math.log10(bow))
+                        line += "\t{0}".format("%.7f" % math.log10(bow))
                     print(line, file=fout)
-            print('', file=fout)
-        print('\\end\\', file=fout)
+            print("", file=fout)
+        print("\\end\\", file=fout)
 
 
 if __name__ == "__main__":
@@ -379,5 +440,5 @@ if __name__ == "__main__":
     if args.lm is None:
         ngram_counts.print_as_arpa()
     else:
-        with open(args.lm, 'w', encoding=default_encoding) as f:
+        with open(args.lm, "w", encoding=default_encoding) as f:
             ngram_counts.print_as_arpa(fout=f)
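
For reference, the hunks above reformat the discounting logic in make_phone_lm.py: the constant D = n1 / (n1 + 2 * n2) is derived from the counts-of-counts and subtracted from each raw count before normalizing. A minimal standalone sketch of that step, using toy counts and illustrative names that are not taken from the script:

```python
# Minimal sketch of the absolute-discounting step, with toy bigram counts.
from collections import Counter

counts = Counter({("a", "b"): 3, ("a", "c"): 1, ("b", "c"): 1})

# counts-of-counts: n1 = bigrams seen exactly once, n2 = bigrams seen exactly twice
n1 = sum(1 for c in counts.values() if c == 1)
n2 = sum(1 for c in counts.values() if c == 2)

# discounting constant, guarded against n1 == 0 as in the script
D = max(0.1, n1 * 1.0) / (n1 + 2 * n2)

# discounted relative frequencies f(w|h) = max(c - D, 0) / total(h)
total = {}
for (h, w), c in counts.items():
    total[h] = total.get(h, 0) + c

f = {(h, w): max(c - D, 0) / total[h] for (h, w), c in counts.items()}
print(D, f)
```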
diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py
new file mode 100755
index 000000000..b1ebee9ea
--- /dev/null
+++ b/icefall/shared/ngram_entropy_pruning.py
@@ -0,0 +1,630 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright 2021  Johns Hopkins University (Author: Ruizhe Huang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+./ngram_entropy_pruning.py \
+    -threshold 1e-8 \
+    -lm download/lm/4gram.arpa \
+    -write-lm download/lm/4gram_pruned_1e8.arpa
+
+This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.
+This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
+in the same way as SRILM.
+"""
+
+
+import argparse
+import gzip
+import logging
+import math
+import re
+from collections import OrderedDict, defaultdict
+from enum import Enum, unique
+from io import StringIO
+
+parser = argparse.ArgumentParser(
+    description="""
+    Prune an n-gram language model based on the relative entropy 
+    between the original and the pruned model, based on Andreas Stolcke's paper.
+    An n-gram entry is removed, if the removal causes (training set) perplexity 
+    of the model to increase by less than threshold relative.
+    
+    The command takes an arpa file and a pruning threshold as input, 
+    and outputs a pruned arpa file.
+    """
+)
+parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram")
+parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file")
+parser.add_argument(
+    "-write-lm", type=str, default=None, help="Path to output arpa file after pruning"
+)
+parser.add_argument(
+    "-minorder",
+    type=int,
+    default=1,
+    help="The minorder parameter limits pruning to ngrams of that length and above.",
+)
+parser.add_argument(
+    "-encoding", type=str, default="utf-8", help="Encoding of the arpa file"
+)
+parser.add_argument(
+    "-verbose",
+    type=int,
+    default=2,
+    choices=[0, 1, 2, 3, 4, 5],
+    help="Verbose level, where 0 is most noisy; 5 is most silent",
+)
+args = parser.parse_args()
+
+default_encoding = args.encoding
+logging.basicConfig(
+    format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
+    level=args.verbose * 10,
+)
+
+
+class Context(dict):
+    """
+    This class stores data for a context h.
+    It behaves like a python dict object, except that it has several
+    additional attributes.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.log_bo = None
+
+
+class Arpa:
+    """
+    This is a class that implements the data structure of an ARPA LM.
+    It (as well as some other classes) is modified based on the library
+    by Stefan Fischer:
+    https://github.com/sfischer13/python-arpa
+    """
+
+    UNK = ""
+    SOS = ""
+    EOS = ""
+    FLOAT_NDIGITS = 7
+    base = 10
+
+    @staticmethod
+    def _check_input(my_input):
+        if not my_input:
+            raise ValueError
+        elif isinstance(my_input, tuple):
+            return my_input
+        elif isinstance(my_input, list):
+            return tuple(my_input)
+        elif isinstance(my_input, str):
+            return tuple(my_input.strip().split(" "))
+        else:
+            raise ValueError
+
+    @staticmethod
+    def _check_word(input_word):
+        if not isinstance(input_word, str):
+            raise ValueError
+        if " " in input_word:
+            raise ValueError
+
+    def _replace_unks(self, words):
+        return tuple((w if w in self else self._unk) for w in words)
+
+    def __init__(self, path=None, encoding=None, unk=None):
+        self._counts = OrderedDict()
+        self._ngrams = (
+            OrderedDict()
+        )  # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
+        self._vocabulary = set()
+        if unk is None:
+            self._unk = self.UNK
+
+        if path is not None:
+            self.loadf(path, encoding)
+
+    def __contains__(self, ngram):
+        h = ngram[:-1]  # h is a tuple
+        w = ngram[-1]  # w is a string/word
+        return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]
+
+    def contains_word(self, word):
+        self._check_word(word)
+        return word in self._vocabulary
+
+    def add_count(self, order, count):
+        self._counts[order] = count
+        self._ngrams[order - 1] = defaultdict(Context)
+
+    def update_counts(self):
+        for order in range(1, self.order() + 1):
+            count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()])
+            if count > 0:
+                self._counts[order] = count
+
+    def add_entry(self, ngram, p, bo=None, order=None):
+        # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
+        h = ngram[:-1]  # h is a tuple
+        w = ngram[-1]  # w is a string/word
+
+        # Note that p and bo here are in fact in the log domain (self.base = 10)
+        h_context = self._ngrams[len(h)][h]
+        h_context[w] = p
+        if bo is not None:
+            self._ngrams[len(ngram)][ngram].log_bo = bo
+
+        for word in ngram:
+            self._vocabulary.add(word)
+
+    def counts(self):
+        return sorted(self._counts.items())
+
+    def order(self):
+        return max(self._counts.keys(), default=None)
+
+    def vocabulary(self, sort=True):
+        if sort:
+            return sorted(self._vocabulary)
+        else:
+            return self._vocabulary
+
+    def _entries(self, order):
+        return (
+            self._entry(h, w)
+            for h, wlist in self._ngrams[order - 1].items()
+            for w in wlist
+        )
+
+    def _entry(self, h, w):
+        # return the entry for the ngram (h, w)
+        ngram = h + (w,)
+        log_p = self._ngrams[len(h)][h][w]
+        log_bo = self._log_bo(ngram)
+        if log_bo is not None:
+            return (
+                round(log_p, self.FLOAT_NDIGITS),
+                ngram,
+                round(log_bo, self.FLOAT_NDIGITS),
+            )
+        else:
+            return round(log_p, self.FLOAT_NDIGITS), ngram
+
+    def _log_bo(self, ngram):
+        if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
+            return self._ngrams[len(ngram)][ngram].log_bo
+        else:
+            return None
+
+    def _log_p(self, ngram):
+        h = ngram[:-1]  # h is a tuple
+        w = ngram[-1]  # w is a string/word
+        if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
+            return self._ngrams[len(h)][h][w]
+        else:
+            return None
+
+    def log_p_raw(self, ngram):
+        log_p = self._log_p(ngram)
+        if log_p is not None:
+            return log_p
+        else:
+            if len(ngram) == 1:
+                raise KeyError
+            else:
+                log_bo = self._log_bo(ngram[:-1])
+                if log_bo is None:
+                    log_bo = 0
+                return log_bo + self.log_p_raw(ngram[1:])
+
+    def log_joint_prob(self, sequence):
+        # Compute the joint prob of the sequence based on the chain rule
+        # Note that sequence should be a tuple of strings
+        #
+        # Reference:
+        # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
+
+        log_joint_p = 0
+        seq = sequence
+        while len(seq) > 0:
+            log_joint_p += self.log_p_raw(seq)
+            seq = seq[:-1]
+
+            # If we're computing the marginal probability of the unigram
+            # <s> context we have to look up </s> instead since the former
+            # has prob = 0.
+            if len(seq) == 1 and seq[0] == self.SOS:
+                seq = (self.EOS,)
+
+        return log_joint_p
+
+    def set_new_context(self, h):
+        old_context = self._ngrams[len(h)][h]
+        self._ngrams[len(h)][h] = Context()
+        return old_context
+
+    def log_p(self, ngram):
+        words = self._check_input(ngram)
+        if self._unk:
+            words = self._replace_unks(words)
+        return self.log_p_raw(words)
+
+    def log_s(self, sentence, sos=SOS, eos=EOS):
+        words = self._check_input(sentence)
+        if self._unk:
+            words = self._replace_unks(words)
+        if sos:
+            words = (sos,) + words
+        if eos:
+            words = words + (eos,)
+        result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
+        if sos:
+            result = result - self.log_p_raw(words[:1])
+        return result
+
+    def p(self, ngram):
+        return self.base ** self.log_p(ngram)
+
+    def s(self, sentence):
+        return self.base ** self.log_s(sentence)
+
+    def write(self, fp):
+        fp.write("\n\\data\\\n")
+        for order, count in self.counts():
+            fp.write("ngram {}={}\n".format(order, count))
+        fp.write("\n")
+        for order, _ in self.counts():
+            fp.write("\\{}-grams:\n".format(order))
+            for e in self._entries(order):
+                prob = e[0]
+                ngram = " ".join(e[1])
+                if len(e) == 2:
+                    fp.write("{}\t{}\n".format(prob, ngram))
+                elif len(e) == 3:
+                    backoff = e[2]
+                    fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff))
+                else:
+                    raise ValueError
+            fp.write("\n")
+        fp.write("\\end\\\n")
+
+
+class ArpaParser:
+    """
+    This is a class that implements a parser of an ARPA file
+    """
+
+    @unique
+    class State(Enum):
+        DATA = 1
+        COUNT = 2
+        HEADER = 3
+        ENTRY = 4
+
+    re_count = re.compile(r"^ngram (\d+)=(\d+)$")
+    re_header = re.compile(r"^\\(\d+)-grams:$")
+    re_entry = re.compile(
+        "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)"
+        "\t"
+        "(\\S+( \\S+)*)"
+        "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$"
+    )
+
+    def _parse(self, fp):
+        self._result = []
+        self._state = self.State.DATA
+        self._tmp_model = None
+        self._tmp_order = None
+        for line in fp:
+            line = line.strip()
+            if self._state == self.State.DATA:
+                self._data(line)
+            elif self._state == self.State.COUNT:
+                self._count(line)
+            elif self._state == self.State.HEADER:
+                self._header(line)
+            elif self._state == self.State.ENTRY:
+                self._entry(line)
+        if self._state != self.State.DATA:
+            raise Exception(line)
+        return self._result
+
+    def _data(self, line):
+        if line == "\\data\\":
+            self._state = self.State.COUNT
+            self._tmp_model = Arpa()
+        else:
+            pass  # skip comment line
+
+    def _count(self, line):
+        match = self.re_count.match(line)
+        if match:
+            order = match.group(1)
+            count = match.group(2)
+            self._tmp_model.add_count(int(order), int(count))
+        elif not line:
+            self._state = self.State.HEADER  # there are no counts
+        else:
+            raise Exception(line)
+
+    def _header(self, line):
+        match = self.re_header.match(line)
+        if match:
+            self._state = self.State.ENTRY
+            self._tmp_order = int(match.group(1))
+        elif line == "\\end\\":
+            self._result.append(self._tmp_model)
+            self._state = self.State.DATA
+            self._tmp_model = None
+            self._tmp_order = None
+        elif not line:
+            pass  # skip empty line
+        else:
+            raise Exception(line)
+
+    def _entry(self, line):
+        match = self.re_entry.match(line)
+        if match:
+            p = self._float_or_int(match.group(1))
+            ngram = tuple(match.group(4).split(" "))
+            bo_match = match.group(7)
+            bo = self._float_or_int(bo_match) if bo_match else None
+            self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
+        elif not line:
+            self._state = self.State.HEADER  # last entry
+        else:
+            raise Exception(line)
+
+    @staticmethod
+    def _float_or_int(s):
+        f = float(s)
+        i = int(f)
+        if str(i) == s:  # don't drop trailing ".0"
+            return i
+        else:
+            return f
+
+    def load(self, fp):
+        """Deserialize fp (a file-like object) to a Python object."""
+        return self._parse(fp)
+
+    def loadf(self, path, encoding=None):
+        """Deserialize path (.arpa, .gz) to a Python object."""
+        path = str(path)
+        if path.endswith(".gz"):
+            with gzip.open(path, mode="rt", encoding=encoding) as f:
+                return self.load(f)
+        else:
+            with open(path, mode="rt", encoding=encoding) as f:
+                return self.load(f)
+
+    def loads(self, s):
+        """Deserialize s (a str) to a Python object."""
+        with StringIO(s) as f:
+            return self.load(f)
+
+    def dump(self, obj, fp):
+        """Serialize obj to fp (a file-like object) in ARPA format."""
+        obj.write(fp)
+
+    def dumpf(self, obj, path, encoding=None):
+        """Serialize obj to path in ARPA format (.arpa, .gz)."""
+        path = str(path)
+        if path.endswith(".gz"):
+            with gzip.open(path, mode="wt", encoding=encoding) as f:
+                return self.dump(obj, f)
+        else:
+            with open(path, mode="wt", encoding=encoding) as f:
+                self.dump(obj, f)
+
+    def dumps(self, obj):
+        """Serialize obj to an ARPA formatted str."""
+        with StringIO() as f:
+            self.dump(obj, f)
+            return f.getvalue()
+
+
+def add_log_p(prev_log_sum, log_p, base):
+    return math.log(base**log_p + base**prev_log_sum, base)
+
+
+def compute_numerator_denominator(lm, h):
+    log_sum_seen_h = -math.inf
+    log_sum_seen_h_lower = -math.inf
+    base = lm.base
+    for w, log_p in lm._ngrams[len(h)][h].items():
+        log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
+
+        ngram = h + (w,)
+        log_p_lower = lm.log_p_raw(ngram[1:])
+        log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base)
+
+    numerator = 1.0 - base**log_sum_seen_h
+    denominator = 1.0 - base**log_sum_seen_h_lower
+    return numerator, denominator
+
+
+def prune(lm, threshold, minorder):
+    # Reference:
+    # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
+
+    for i in range(
+        lm.order(), max(minorder - 1, 1), -1
+    ):  # i is the order of the ngram (h, w)
+        logging.info("processing %d-grams ..." % i)
+        count_pruned_ngrams = 0
+
+        h_dict = lm._ngrams[i - 1]
+        for h in list(h_dict.keys()):
+            # old backoff weight, BOW(h)
+            log_bow = lm._log_bo(h)
+            if log_bow is None:
+                log_bow = 0
+
+            # Compute numerator and denominator of the backoff weight,
+            # so that we can quickly compute the BOW adjustment due to
+            # leaving out one prob.
+            numerator, denominator = compute_numerator_denominator(lm, h)
+
+            # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
+
+            # Compute the marginal probability of the context, P(h)
+            h_log_p = lm.log_joint_prob(h)
+
+            all_pruned = True
+            pruned_w_set = set()
+
+            for w, log_p in h_dict[h].items():
+                ngram = h + (w,)
+
+                # lower-order estimate for ngramProb, P(w|h')
+                backoff_prob = lm.log_p_raw(ngram[1:])
+
+                # Compute BOW after removing ngram, BOW'(h)
+                new_log_bow = math.log(
+                    numerator + lm.base**log_p, lm.base
+                ) - math.log(denominator + lm.base**backoff_prob, lm.base)
+
+                # Compute change in entropy due to removal of ngram
+                delta_prob = backoff_prob + new_log_bow - log_p
+                delta_entropy = -(lm.base**h_log_p) * (
+                    (lm.base**log_p) * delta_prob
+                    + numerator * (new_log_bow - log_bow)
+                )
+
+                # compute relative change in model (training set) perplexity
+                perp_change = lm.base**delta_entropy - 1.0
+
+                pruned = threshold > 0 and perp_change < threshold
+
+                # Make sure we don't prune ngrams whose backoff nodes are needed
+                if (
+                    pruned
+                    and len(ngram) in lm._ngrams
+                    and len(lm._ngrams[len(ngram)][ngram]) > 0
+                ):
+                    pruned = False
+
+                logging.debug(
+                    "CONTEXT "
+                    + str(h)
+                    + " WORD "
+                    + w
+                    + " CONTEXTPROB %f " % h_log_p
+                    + " OLDPROB %f " % log_p
+                    + " NEWPROB %f " % (backoff_prob + new_log_bow)
+                    + " DELTA-H %f " % delta_entropy
+                    + " DELTA-LOGP %f " % delta_prob
+                    + " PPL-CHANGE %f " % perp_change
+                    + " PRUNED "
+                    + str(pruned)
+                )
+
+                if pruned:
+                    pruned_w_set.add(w)
+                    count_pruned_ngrams += 1
+                else:
+                    all_pruned = False
+
+            # If we removed all ngrams for this context we can
+            # remove the context itself, but only if the present
+            # context is not a prefix to a longer one.
+            if all_pruned and len(pruned_w_set) == len(h_dict[h]):
+                del h_dict[
+                    h
+                ]  # this context h is no longer needed, as its ngram prob is stored at its own context h'
+            elif len(pruned_w_set) > 0:
+                # The pruning for this context h is actually done here
+                old_context = lm.set_new_context(h)
+
+                for w, p_w in old_context.items():
+                    if w not in pruned_w_set:
+                        lm.add_entry(
+                            h + (w,), p_w
+                        )  # the entry hw is stored at the context h
+
+                # We need to recompute the back-off weight, but
+                # this can only be done after completing the pruning
+                # of the lower-order ngrams.
+                # Reference:
+                # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
+
+        logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
+
+    # recompute backoff weights
+    for i in range(
+        max(minorder - 1, 1) + 1, lm.order() + 1
+    ):  # be careful of this order: from low- to high-order
+        for h in lm._ngrams[i - 1]:
+            numerator, denominator = compute_numerator_denominator(lm, h)
+            new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base)
+            lm._ngrams[len(h)][h].log_bo = new_log_bow
+
+    # update counts
+    lm.update_counts()
+
+    return
+
+
+def check_h_is_valid(lm, h):
+    sum_under_h = sum(
+        [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)]
+    )
+    if abs(sum_under_h - 1.0) > 1e-6:
+        logging.info("warning: %s %f" % (str(h), sum_under_h))
+        return False
+    else:
+        return True
+
+
+def validate_lm(lm):
+    # sanity check if the conditional probability sums to one under each context h
+    for i in range(lm.order(), 0, -1):  # i is the order of the ngram (h, w)
+        logging.info("validating %d-grams ..." % i)
+        h_dict = lm._ngrams[i - 1]
+        for h in h_dict.keys():
+            check_h_is_valid(lm, h)
+
+
+def compare_two_apras(path1, path2):
+    pass
+
+
+if __name__ == "__main__":
+    # load an arpa file
+    logging.info("Loading the arpa file from %s" % args.lm)
+    parser = ArpaParser()
+    models = parser.loadf(args.lm, encoding=default_encoding)
+    lm = models[0]  # ARPA files may contain several models.
+    logging.info("Stats before pruning:")
+    for i, cnt in lm.counts():
+        logging.info("ngram %d=%d" % (i, cnt))
+
+    # prune it, the language model will be modified in-place
+    logging.info("Start pruning the model with threshold=%.3E..." % args.threshold)
+    prune(lm, args.threshold, args.minorder)
+
+    # validate_lm(lm)
+
+    # write the arpa language model to a file
+    logging.info("Stats after pruning:")
+    for i, cnt in lm.counts():
+        logging.info("ngram %d=%d" % (i, cnt))
+    logging.info("Saving the pruned arpa file to %s" % args.write_lm)
+    parser.dumpf(lm, args.write_lm, encoding=default_encoding)
+    logging.info("Done.")
diff --git a/icefall/transformer_lm/__init__.py b/icefall/transformer_lm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/icefall/transformer_lm/attention.py b/icefall/transformer_lm/attention.py
new file mode 100644
index 000000000..5ce83b15e
--- /dev/null
+++ b/icefall/transformer_lm/attention.py
@@ -0,0 +1,510 @@
+# Copyright (c)  2021  University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor, nn
+
+from icefall.transformer_lm.scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledLinear,
+)
+from icefall.utils import is_jit_tracing
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+
+    Examples::
+
+        >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
+        >>> attn_output, attn_output_weights = rel_pos_multihead_attn(query, key, value, pos_emb)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ) -> None:
+        super(RelPositionMultiheadAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert (
+            self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = ScaledLinear(
+            embed_dim, embed_dim, bias=True, initial_scale=0.25
+        )
+
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
+        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
+        self._reset_parameters()
+
+    def _pos_bias_u(self):
+        return self.pos_bias_u * self.pos_bias_u_scale.exp()
+
+    def _pos_bias_v(self):
+        return self.pos_bias_v * self.pos_bias_v_scale.exp()
+
+    def _reset_parameters(self) -> None:
+        nn.init.normal_(self.pos_bias_u, std=0.01)
+        nn.init.normal_(self.pos_bias_v, std=0.01)
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[Tensor] = None,
+        left_context: int = 0,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        r"""
+        Args:
+            query, key, value: map a query and a set of key-value pairs to an output.
+            pos_emb: Positional embedding tensor
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. When given a binary mask and a value is True,
+                the corresponding value on the attention layer will be ignored. When given
+                a byte mask and a value is non-zero, the corresponding value on the attention
+                layer will be ignored
+            need_weights: output attn_output_weights.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+            left_context (int): left context (in frames) used during streaming decoding.
+                this is used only in real streaming decoding, in other circumstances,
+                it MUST be 0.
+
+        Shape:
+            - Inputs:
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero
+            positions will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+            - Outputs:
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+        return self.multi_head_attention_forward(
+            query,
+            key,
+            value,
+            pos_emb,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj.get_weight(),
+            self.in_proj.get_bias(),
+            self.dropout,
+            self.out_proj.get_weight(),
+            self.out_proj.get_bias(),
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+            left_context=left_context,
+        )
+
+    def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
+        """Compute relative positional encoding.
+
+        Args:
+            x: Input tensor (batch, head, time1, 2*time1-1+left_context).
+                time1 means the length of query vector.
+            left_context (int): left context (in frames) used during streaming decoding.
+                this is used only in real streaming decoding, in other circumstances,
+                it MUST be 0.
+
+        Returns:
+            Tensor: tensor of shape (batch, head, time1, time2)
+          (note: time2 has the same value as time1, but it is for
+          the key, while time1 is for the query).
+        """
+        (batch_size, num_heads, time1, n) = x.shape
+
+        time2 = time1 + left_context
+        if not is_jit_tracing():
+            assert (
+                n == left_context + 2 * time1 - 1
+            ), f"{n} == {left_context} + 2 * {time1} - 1"
+
+        if is_jit_tracing():
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(time2)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+
+            x = x.reshape(-1, n)
+            x = torch.gather(x, dim=1, index=indexes)
+            x = x.reshape(batch_size, num_heads, time1, time2)
+            return x
+        else:
+            # Note: TorchScript requires explicit arg for stride()
+            batch_stride = x.stride(0)
+            head_stride = x.stride(1)
+            time1_stride = x.stride(2)
+            n_stride = x.stride(3)
+            return x.as_strided(
+                (batch_size, num_heads, time1, time2),
+                (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+                storage_offset=n_stride * (time1 - 1),
+            )
+
+    def multi_head_attention_forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: Tensor,
+        embed_dim_to_check: int,
+        num_heads: int,
+        in_proj_weight: Tensor,
+        in_proj_bias: Tensor,
+        dropout_p: float,
+        out_proj_weight: Tensor,
+        out_proj_bias: Tensor,
+        training: bool = True,
+        key_padding_mask: Optional[Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[Tensor] = None,
+        left_context: int = 0,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        r"""
+        Args:
+            query, key, value: map a query and a set of key-value pairs to an output.
+            pos_emb: Positional embedding tensor
+            embed_dim_to_check: total dimension of the model.
+            num_heads: parallel attention heads.
+            in_proj_weight, in_proj_bias: input projection weight and bias.
+            dropout_p: probability of an element to be zeroed.
+            out_proj_weight, out_proj_bias: the output projection weight and bias.
+            training: apply dropout if is ``True``.
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. This is a binary mask. When the value is True,
+                the corresponding value on the attention layer will be filled with -inf.
+            need_weights: output attn_output_weights.
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+            left_context (int): left context (in frames) used during streaming decoding.
+                this is used only in real streaming decoding, in other circumstances,
+                it MUST be 0.
+
+        Shape:
+            Inputs:
+            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+            the embedding dimension.
+            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+            length, N is the batch size, E is the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+            If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+            will be unchanged. If a BoolTensor is provided, the positions with the
+            value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+            3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+            S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+            positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+            while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+            is provided, it will be added to the attention weight.
+
+            Outputs:
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+            E is the embedding dimension.
+            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+            L is the target sequence length, S is the source sequence length.
+        """
+
+        tgt_len, bsz, embed_dim = query.size()
+        if not is_jit_tracing():
+            assert embed_dim == embed_dim_to_check
+            assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+        head_dim = embed_dim // num_heads
+        if not is_jit_tracing():
+            assert (
+                head_dim * num_heads == embed_dim
+            ), "embed_dim must be divisible by num_heads"
+
+        scaling = float(head_dim) ** -0.5
+
+        if torch.equal(query, key) and torch.equal(key, value):
+            # self-attention
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+                3, dim=-1
+            )
+
+        elif torch.equal(key, value):
+            # encoder-decoder attention
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+
+        else:
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = 0
+            _end = embed_dim
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            q = nn.functional.linear(query, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim
+            _end = embed_dim * 2
+            _w = in_proj_weight[_start:_end, :]
+            if _b is not None:
+                _b = _b[_start:_end]
+            k = nn.functional.linear(key, _w, _b)
+
+            # This is inline in_proj function with in_proj_weight and in_proj_bias
+            _b = in_proj_bias
+            _start = embed_dim * 2
+            _end = None
+            _w = in_proj_weight[_start:, :]
+            if _b is not None:
+                _b = _b[_start:]
+            v = nn.functional.linear(value, _w, _b)
+
+        if attn_mask is not None:
+            assert (
+                attn_mask.dtype == torch.float32
+                or attn_mask.dtype == torch.float64
+                or attn_mask.dtype == torch.float16
+                or attn_mask.dtype == torch.uint8
+                or attn_mask.dtype == torch.bool
+            ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
+                attn_mask.dtype
+            )
+            if attn_mask.dtype == torch.uint8:
+                warnings.warn(
+                    "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
+                )
+                attn_mask = attn_mask.to(torch.bool)
+
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
+            elif attn_mask.dim() == 3:
+                if list(attn_mask.size()) != [
+                    bsz * num_heads,
+                    query.size(0),
+                    key.size(0),
+                ]:
+                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
+            else:
+                raise RuntimeError(
+                    "attn_mask's dimension {} is not supported".format(attn_mask.dim())
+                )
+            # attn_mask's dim is 3 now.
+
+        # convert ByteTensor key_padding_mask to bool
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+            warnings.warn(
+                "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+            )
+            key_padding_mask = key_padding_mask.to(torch.bool)
+
+        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        k = k.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+        src_len = k.size(0)
+
+        if key_padding_mask is not None and not is_jit_tracing():
+            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+                key_padding_mask.size(0), bsz
+            )
+            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
+                key_padding_mask.size(1), src_len
+            )
+
+        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+
+        pos_emb_bsz = pos_emb.size(0)
+        if not is_jit_tracing():
+            assert pos_emb_bsz in (1, bsz)  # actually it is 1
+
+        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+        # (batch, 2*time1-1, head, d_k) --> (batch, head, d_k, 2*time1-1)
+        p = p.permute(0, 2, 3, 1)
+
+        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch, head, time1, time2)
+
+        # compute matrix b and matrix d
+        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd, left_context)
+
+        attn_output_weights = matrix_ac + matrix_bd  # (batch, head, time1, time2)
+
+        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
+
+        if not is_jit_tracing():
+            assert list(attn_output_weights.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                src_len,
+            ]
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+            else:
+                attn_output_weights += attn_mask
+
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
+        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+
+        # If we are using dynamic_chunk_training and setting a limited
+        # num_left_chunks, the attention may only see the padding values which
+        # will also be masked out by `key_padding_mask`, at this circumstances,
+        # the whole column of `attn_output_weights` will be `-inf`
+        # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
+        # positions to avoid invalid loss value below.
+        if (
+            attn_mask is not None
+            and attn_mask.dtype == torch.bool
+            and key_padding_mask is not None
+        ):
+            if attn_mask.size(0) != 1:
+                attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
+
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
+        attn_output_weights = nn.functional.dropout(
+            attn_output_weights, p=dropout_p, training=training
+        )
+
+        attn_output = torch.bmm(attn_output_weights, v)
+
+        if not is_jit_tracing():
+            assert list(attn_output.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                head_dim,
+            ]
+
+        attn_output = (
+            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
+
+        if need_weights:
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            return attn_output, attn_output_weights.sum(dim=1) / num_heads
+        else:
+            return attn_output, None
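
The rel_shift() method above has two equivalent branches: a gather-based one used under JIT tracing and an as_strided view used otherwise. A small toy check, assuming a contiguous input and left_context=0, showing that both branches select the same relative-position scores:

```python
# Toy check that the two rel_shift branches produce the same result.
import torch

batch, heads, time1, left_context = 1, 1, 3, 0
n = left_context + 2 * time1 - 1
time2 = time1 + left_context

x = torch.arange(batch * heads * time1 * n, dtype=torch.float32).reshape(
    batch, heads, time1, n
)

# gather-based version (mirrors the is_jit_tracing() branch)
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time2)
indexes = rows.repeat(batch * heads).unsqueeze(-1) + cols
gathered = torch.gather(x.reshape(-1, n), dim=1, index=indexes).reshape(
    batch, heads, time1, time2
)

# as_strided version (mirrors the non-tracing branch)
strided = x.as_strided(
    (batch, heads, time1, time2),
    (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
    storage_offset=x.stride(3) * (time1 - 1),
)

assert torch.equal(gathered, strided)
print(strided.squeeze())
```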
diff --git a/icefall/transformer_lm/compute_perplexity.py b/icefall/transformer_lm/compute_perplexity.py
new file mode 100644
index 000000000..72d7c477b
--- /dev/null
+++ b/icefall/transformer_lm/compute_perplexity.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang
+#                                                  Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+from pathlib import Path
+
+import torch
+from dataset import get_dataloader
+from train import get_params
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.transformer_lm.model import TransformerLM
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=7,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=1,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="transformer_lm/exp_full_libri_16layer_maxlen200_8gpu",
+    )
+
+    parser.add_argument(
+        "--lm-data",
+        type=str,
+        help="Path to the LM test data for computing perplexity",
+        default="transformer_lm/libri_lm_training_bpe500/sorted_lm_data-test.pt",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        default=500,
+        help="Vocabulary size of the model",
+    )
+
+    parser.add_argument(
+        "--num-layers",
+        type=int,
+        default=16,
+        help="Number of RNN layers the model",
+    )
+
+    parser.add_argument(
+        "--tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=50,
+        help="Number of RNN layers the model",
+    )
+
+    parser.add_argument(
+        "--max-sent-len",
+        type=int,
+        default=100,
+        help="Number of RNN layers the model",
+    )
+
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lm_data = Path(args.lm_data)
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-ppl/")
+    logging.info("Computing perplexity started")
+    logging.info(params)
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    logging.info("About to create model")
+    model = TransformerLM(
+        vocab_size=params.vocab_size,
+        d_model=params.encoder_dim,
+        embedding_dim=params.embedding_dim,
+        dim_feedforward=params.dim_feedforward,
+        nhead=params.nhead,
+        num_layers=params.num_layers,
+        tie_weights=params.tie_weights,
+        params=params,
+    )
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        model.to(device)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if start >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    num_param_requires_grad = sum(
+        [p.numel() for p in model.parameters() if p.requires_grad]
+    )
+
+    logging.info(f"Number of model parameters: {num_param}")
+    logging.info(
+        f"Number of model parameters (requires_grad): "
+        f"{num_param_requires_grad} "
+        f"({num_param_requires_grad/num_param_requires_grad*100}%)"
+    )
+
+    logging.info(f"Loading LM test data from {params.lm_data}")
+    test_dl = get_dataloader(
+        filename=params.lm_data,
+        is_distributed=False,
+        params=params,
+    )
+
+    tot_loss = 0.0
+    num_tokens = 0
+    num_sentences = 0
+    for batch_idx, batch in enumerate(test_dl):
+        x, y, sentence_lengths = batch
+        x = x.to(device)
+        y = y.to(device)
+        sentence_lengths = sentence_lengths.to(device)
+
+        nll = model(x, y, sentence_lengths)
+        loss = nll.sum().cpu().item()
+
+        tot_loss += loss
+        num_tokens += sentence_lengths.sum().cpu().item()
+        num_sentences += x.size(0)
+
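+    # Perplexity is the exponentiated average negative log-likelihood per
+    # token, i.e., ppl = exp(total_nll / num_tokens).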
+    ppl = math.exp(tot_loss / num_tokens)
+    logging.info(
+        f"total nll: {tot_loss}, num tokens: {num_tokens}, "
+        f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/icefall/transformer_lm/dataset.py b/icefall/transformer_lm/dataset.py
new file mode 120000
index 000000000..5792a6cf0
--- /dev/null
+++ b/icefall/transformer_lm/dataset.py
@@ -0,0 +1 @@
+../rnn_lm/dataset.py
\ No newline at end of file
diff --git a/icefall/transformer_lm/encoder.py b/icefall/transformer_lm/encoder.py
new file mode 100644
index 000000000..4357b83d7
--- /dev/null
+++ b/icefall/transformer_lm/encoder.py
@@ -0,0 +1,329 @@
+# Copyright (c)  2021  Xiaomi Corporation (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from icefall.transformer_lm.attention import RelPositionMultiheadAttention
+from icefall.transformer_lm.scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledConv1d,
+    ScaledConv2d,
+    ScaledLinear,
+)
+from icefall.utils import is_jit_tracing, make_pad_mask
+
+
+class Transformer(torch.nn.Module):
+    """_summary_
+
+    Args:
+        input_dim (int): Input feature dimension
+        d_mode (int): The dimension of the transformer
+        dim_feedforward (int ): The dimension of the ffw module
+        nhead (int): The number of attention heads
+        dropout_rate (float): dropout rate
+        att_dropout (float): dropout rate in attention module
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        d_model: int,
+        dim_feedforward: int,
+        nhead: int = 4,
+        num_layers: int = 6,
+        dropout_rate: float = 0.1,
+        att_dropout: float = 0.0,
+    ):
+        super().__init__()
+
+        self.encoder_layers = num_layers
+        self.d_model = d_model
+
+        self.embed = ScaledLinear(input_dim, d_model)
+        self.norm_before = BasicNorm(d_model, learn_eps=False)
+
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout_rate)
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model=d_model,
+            dim_feedforward=dim_feedforward,
+            nhead=nhead,
+            dropout_rate=dropout_rate,
+        )
+
+        self.encoder = TransformerEncoder(encoder_layer, num_layers)
+
+    def _create_attention_mask(self, x_lens: torch.Tensor):
+        # create a 2D attention mask to mask out
+        # the upper right half of the attention matrix
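+        # For example, with max_len = 3 the mask is
+        #   [[False,  True,  True],
+        #    [False, False,  True],
+        #    [False, False, False]]
+        # i.e., position i cannot attend to positions j > i (causal masking).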
+        max_len = max(x_lens)
+        ones = torch.ones(max_len, max_len, device=x_lens.device, dtype=torch.bool)
+        return torch.triu(ones, diagonal=1)
+
+    def forward(
+        self, x: torch.Tensor, x_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Transformer forward
+
+        Args:
+            x (torch.Tensor): Input tensor (B,T,input_dim)
+            x_lens (torch.Tensor): The length of input tensors before padding (B,)
+
+        Returns:
+            Return a tuple of 2 tensors:
+            - x: output feature of the transformer (B,T,d_model)
+            - x_lens: output feature lens of the transformer
+        """
+
+        attention_mask = self._create_attention_mask(x_lens)
+        src_key_padding_mask = make_pad_mask(x_lens)
+
+        x = self.norm_before(self.embed(x))
+
+        x, pos_emb = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)
+
+        x = self.encoder(
+            x,
+            pos_emb,
+            mask=attention_mask,  # pass the causal attention mask
+            src_key_padding_mask=src_key_padding_mask,
+        )  # (T, N, C)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        return x, x_lens
+
+
+class TransformerEncoder(torch.nn.Module):
+    def __init__(self, encoder_layer: torch.nn.Module, num_layers: int) -> None:
+        """TransformerEncoder is a stack of N encoder layers
+
+        Args:
+            encoder_layer (torch.nn.Module): an instance of the TransformerEncoderLayer()
+            num_layers (int): Number of layers to be stacked
+        """
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        pos_emb: torch.Tensor,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """_summary_
+
+        Args:
+            src: the sequence to the encoder (required).
+            pos_emb: Positional embedding tensor (required).
+            mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Returns:
+            output: transformer encoded features
+        """
+        output = src
+
+        for layer_index, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                pos_emb,
+                src_key_padding_mask=src_key_padding_mask,
+                src_mask=mask,
+            )
+
+        return output
+
+
+class TransformerEncoderLayer(torch.nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        dim_feedforward: int,
+        nhead: int,
+        dropout_rate: float,
+    ):
+        """TransformerEncoderLayer is made up of self-attn and feedforward module
+
+        Args:
+            d_model (int): The model size
+            dim_feedforward (int): Dimension of ffw module
+            nhead (int): Number of heads
+            dropout_rate (float): Dropout rate
+        """
+        super().__init__()
+
+        self.d_model = d_model
+
+        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
+        self.feed_forward = nn.Sequential(
+            ScaledLinear(d_model, dim_feedforward),
+            ActivationBalancer(channel_dim=-1),
+            DoubleSwish(),
+            nn.Dropout(dropout_rate),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+        )
+
+        self.norm_final = BasicNorm(d_model)
+
+        self.balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+        )
+
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        pos_emb: torch.Tensor,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
+        src_mask: Optional[torch.Tensor] = None,
+        cache=None,
+    ):
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+            src: the sequence to the encoder layer (required).
+            pos_emb: Positional embedding tensor (required).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+            src_mask: the mask for the src sequence (optional).
+        """
+        src_orig = src
+
+        src_att = self.self_attn(
+            src,
+            src,
+            src,
+            pos_emb=pos_emb,
+            attn_mask=src_mask,
+            key_padding_mask=src_key_padding_mask,
+        )[0]
+
+        src = src + self.dropout(src_att)
+
+        # feed forward module
+        src = src + self.dropout(self.feed_forward(src))
+
+        src = self.norm_final(self.balancer(src))
+
+        return src
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+        """Construct an PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        if is_jit_tracing():
+            # 10k frames correspond to ~100k ms, i.e., 100 seconds.
+            # It assumes that the maximum input won't have more than
+            # 10k frames.
+            #
+            # TODO(fangjun): Use torch.jit.script() for this module
+            max_len = 10000
+
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
+        """Reset the positional encodings."""
+        x_size_1 = x.size(1) + left_context
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
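+            # (e.g., an input of length 4 covers relative positions
+            # 3, 2, 1, 0, -1, -2, -3)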
+            if self.pe.size(1) >= x_size_1 * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means to the position of query vector and `j` means the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x_size_1, self.d_model)
+        pe_negative = torch.zeros(x_size_1, self.d_model)
+        position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(
+        self, x: torch.Tensor, left_context: int = 0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+            left_context (int): left context (in frames) used during streaming decoding.
+                This is used only in real streaming decoding; in other
+                circumstances, it MUST be 0.
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+
+        """
+        self.extend_pe(x, left_context)
+        x_size_1 = x.size(1) + left_context
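+        # self.pe is centered on relative position 0; slice out the
+        # x.size(1) + x_size_1 - 1 embeddings needed for this input
+        # (2 * time - 1 of them when left_context is 0).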
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x_size_1
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
diff --git a/icefall/transformer_lm/export.py b/icefall/transformer_lm/export.py
new file mode 100644
index 000000000..c08982e37
--- /dev/null
+++ b/icefall/transformer_lm/export.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+# Copyright (c)  2022  Xiaomi Corporation (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from model import TransformerLM
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, load_averaged_model, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=11,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=5,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        default=500,
+        help="Vocabulary size of the model",
+    )
+
+    parser.add_argument(
+        "--embedding-dim",
+        type=int,
+        default=768,
+        help="Embedding dim of the model",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=768,
+        help="Encoder dim of the model",
+    )
+
+    parser.add_argument(
+        "--dim_feedforward",
+        type=int,
+        default=2048,
+        help="Hidden dim of the model",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads",
+    )
+
+    parser.add_argument(
+        "--num-layers",
+        type=int,
+        default=16,
+        help="Number of Transformer layers",
+    )
+
+    parser.add_argument(
+        "--tie-weights",
+        type=str2bool,
+        default=True,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="rnn_lm/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=True,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = AttributeDict({})
+    params.update(vars(args))
+
+    logging.info(params)
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("About to create model")
+    model = TransformerLM(
+        vocab_size=params.vocab_size,
+        d_model=params.encoder_dim,
+        embedding_dim=params.embedding_dim,
+        dim_feedforward=params.dim_feedforward,
+        nhead=params.nhead,
+        num_layers=params.num_layers,
+        tie_weights=params.tie_weights,
+        params=params,
+    )
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    model.to(device)
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        model = load_averaged_model(
+            params.exp_dir, model, params.epoch, params.avg, device
+        )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/icefall/transformer_lm/model.py b/icefall/transformer_lm/model.py
new file mode 100644
index 000000000..79dda3168
--- /dev/null
+++ b/icefall/transformer_lm/model.py
@@ -0,0 +1,115 @@
+# Copyright (c)  2022  Xiaomi Corporation (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from icefall.transformer_lm.encoder import Transformer
+from icefall.utils import AttributeDict, add_eos, add_sos, make_pad_mask
+
+
+class TransformerLM(torch.nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_dim: int,
+        d_model: int,
+        dim_feedforward: int,
+        nhead: int = 8,
+        num_layers: int = 16,
+        tie_weights: bool = True,
+        dropout: float = 0.1,
+        emb_dropout_rate: float = 0.0,
+        params: AttributeDict = None,
+    ):
+        super().__init__()
+
+        self.vocab_size = vocab_size
+        self.params = params
+
+        self.input_embedding = torch.nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=embedding_dim,
+        )
+
+        self.encoder = Transformer(
+            input_dim=embedding_dim,
+            d_model=d_model,
+            dim_feedforward=dim_feedforward,
+            nhead=nhead,
+            num_layers=num_layers,
+            dropout_rate=dropout,
+        )
+
+        self.output_linear = torch.nn.Linear(
+            in_features=d_model, out_features=vocab_size
+        )
+        if tie_weights:
+            logging.info("Tying weights")
+            assert d_model == embedding_dim, (d_model, embedding_dim)
+            self.output_linear.weight = self.input_embedding.weight
+        else:
+            logging.info("Not tying weights")
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        x_lens: torch.Tensor,
+        return_logits: bool = False,
+    ):
+        """Forward transformer language model
+
+        Args:
+            x (torch.Tensor): Input tokens (B,L)
+            y (torch.Tensor): Output tokens (with EOS appended) (B,L)
+            x_lens (torch.Tensor): Length of input tokens before padding (B,)
+            return_logits (bool, optional): Return logits instead of NLL
+
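+        Example (a minimal sketch; the shapes and token IDs below are
+        illustrative only, not taken from a real dataset):
+
+            >>> lm = TransformerLM(vocab_size=500, embedding_dim=768, d_model=768,
+            ...                    dim_feedforward=2048, nhead=8, num_layers=16)
+            >>> x = torch.randint(0, 500, (2, 10))  # (B, L) input tokens
+            >>> y = torch.randint(0, 500, (2, 10))  # (B, L) shifted targets
+            >>> x_lens = torch.tensor([10, 7])
+            >>> nll = lm(x, y, x_lens)              # per-token NLL, shape (B * L,)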
+        """
+
+        x = self.input_embedding(x)
+
+        x, x_lens = self.encoder(x, x_lens)
+
+        logits = self.output_linear(x)
+
+        if return_logits:
+            return logits
+
+        nll_loss = F.cross_entropy(
+            logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
+        )
+
+        mask = make_pad_mask(x_lens).reshape(-1)
+        nll_loss.masked_fill_(mask, 0)
+
+        return nll_loss
+
+    def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
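+        # Score the next token given the prefix `x`: run a full forward pass
+        # and take the log-probabilities at the last non-padded position of
+        # each sequence. `state` is unused and is returned as None, apparently
+        # to keep the same interface as the RNN LM's score_token().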
+
+        bs = x.size(0)
+
+        state = None
+        logits = self.forward(x, x, x_lens, return_logits=True)
+        index = torch.arange(bs)
+
+        last_logits = logits[index, x_lens - 1, :]
+
+        return last_logits.log_softmax(-1), state
diff --git a/icefall/transformer_lm/scaling.py b/icefall/transformer_lm/scaling.py
new file mode 120000
index 000000000..0876c0704
--- /dev/null
+++ b/icefall/transformer_lm/scaling.py
@@ -0,0 +1 @@
+../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py
new file mode 100644
index 000000000..c36abfcdf
--- /dev/null
+++ b/icefall/transformer_lm/train.py
@@ -0,0 +1,609 @@
+#!/usr/bin/env python3
+# Copyright    2021  Xiaomi Corp.        (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Usage:
+    ./transformer_lm/train.py \
+        --start-epoch 0 \
+        --world-size 2 \
+        --num-epochs 1 \
+        --use-fp16 0 \
+        --num-layers 12 \
+        --batch-size 400
+
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from shutil import copyfile
+from typing import Optional, Tuple
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.optim as optim
+from dataset import get_dataloader
+from lhotse.utils import fix_random_seed
+from model import TransformerLM
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from from this epoch.
+        If it is positive, it will load checkpoint from
+        exp_dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="transformer_lm/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=True,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=400,
+    )
+
+    parser.add_argument(
+        "--lm-data",
+        type=str,
+        default="data/lm_training_bpe_500/sorted_lm_data.pt",
+        help="LM training data",
+    )
+
+    parser.add_argument(
+        "--lm-data-valid",
+        type=str,
+        default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
+        help="LM validation data",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        default=500,
+        help="Vocabulary size of the model",
+    )
+
+    parser.add_argument(
+        "--num-layers",
+        type=int,
+        default=12,
+        help="Number of Transformer layers in the model",
+    )
+
+    parser.add_argument(
+        "--tie-weights",
+        type=str2bool,
+        default=True,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters."""
+
+    params = AttributeDict(
+        {
+            "max_sent_len": 200,
+            "sos_id": 1,
+            "eos_id": 1,
+            "blank_id": 0,
+            "lr": 1e-3,
+            "weight_decay": 1e-6,
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 200,
+            "reset_interval": 2000,
+            "valid_interval": 1000,
+            "nhead": 8,
+            "embedding_dim": 768,
+            "encoder_dim": 768,
+            "dim_feedforward": 2048,
+            "dropout": 0.1,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> None:
+    """Load checkpoint from file.
+
+    If params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The learning rate scheduler we are using.
+    Returns:
+      Return None.
+    """
+    if params.start_epoch <= 0:
+        return
+
+    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    logging.info(f"Loading checkpoint: {filename}")
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    model: nn.Module,
+    x: torch.Tensor,
+    y: torch.Tensor,
+    sentence_lengths: torch.Tensor,
+    is_training: bool,
+) -> Tuple[torch.Tensor, MetricsTracker]:
+    """Compute the negative log-likelihood loss given a model and its input.
+    Args:
+      model:
+        The NN model,
+      x:
+        A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
+        each row starts with SOS ID.
+      y:
+        A 2-D tensor. Each row is a shifted version of the corresponding row
+        in `x` but ends with an EOS ID (before padding).
+      sentence_lengths:
+        A 1-D tensor containing number of tokens of each sentence
+        before padding.
+      is_training:
+        True for training. False for validation.
+    """
+    with torch.set_grad_enabled(is_training):
+        device = model.device
+        x = x.to(device)
+        y = y.to(device)
+        sentence_lengths = sentence_lengths.to(device)
+
+        nll = model(x, y, sentence_lengths)
+        loss = nll.sum()
+
+        num_tokens = sentence_lengths.sum().item()
+
+        loss_info = MetricsTracker()
+        # Note: Due to how MetricsTracker() is designed,
+        # we use "frames" instead of "num_tokens" as a key here
+        loss_info["frames"] = num_tokens
+        loss_info["loss"] = loss.detach().item()
+    return loss, loss_info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: nn.Module,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process. The validation loss
+    is saved in `params.valid_loss`.
+    """
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        x, y, sentence_lengths = batch
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                model=model,
+                x=x,
+                y=y,
+                sentence_lengths=sentence_lengths,
+                is_training=False,
+            )
+
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all sentences is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        x, y, sentence_lengths = batch
+        batch_size = x.size(0)
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                model=model,
+                x=x,
+                y=y,
+                sentence_lengths=sentence_lengths,
+                is_training=True,
+            )
+
+        # summary stats
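+        # tot_loss is a running sum that decays old batches by a factor of
+        # (1 - 1/reset_interval) per step, so it roughly reflects the last
+        # `reset_interval` batches.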
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        optimizer.zero_grad()
+        loss.backward()
+        clip_grad_norm_(model.parameters(), 5.0, 2.0)
+        optimizer.step()
+
+        if batch_idx % params.log_interval == 0:
+            # Note: "frames" here means "num_tokens"
+            this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
+            tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
+                f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
+                f"batch size: {batch_size}"
+            )
+
+            if tb_writer is not None:
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+                tb_writer.add_scalar(
+                    "train/current_ppl", this_batch_ppl, params.batch_idx_train
+                )
+
+                tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
+
+        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+            logging.info("Computing validation loss")
+
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                valid_dl=valid_dl,
+                world_size=world_size,
+            )
+            model.train()
+
+            valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
+            logging.info(
+                f"Epoch {params.cur_epoch}, validation: {valid_info}, "
+                f"ppl: {valid_ppl}"
+            )
+
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+                tb_writer.add_scalar(
+                    "train/valid_ppl", valid_ppl, params.batch_idx_train
+                )
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+    is_distributed = world_size > 1
+
+    fix_random_seed(params.seed)
+    if is_distributed:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    logging.info(f"Device: {device}")
+
+    logging.info("About to create model")
+    model = TransformerLM(
+        vocab_size=params.vocab_size,
+        d_model=params.encoder_dim,
+        embedding_dim=params.embedding_dim,
+        dim_feedforward=params.dim_feedforward,
+        nhead=params.nhead,
+        num_layers=params.num_layers,
+        tie_weights=params.tie_weights,
+        params=params,
+    )
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if is_distributed:
+        model = DDP(model, device_ids=[rank])
+
+    model.device = device
+
+    optimizer = optim.Adam(
+        model.parameters(),
+        lr=params.lr,
+        weight_decay=params.weight_decay,
+    )
+    if checkpoints:
+        logging.info("Load optimizer state_dict from checkpoint")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    logging.info(f"Loading LM training data from {params.lm_data}")
+    train_dl = get_dataloader(
+        filename=params.lm_data,
+        is_distributed=is_distributed,
+        params=params,
+    )
+
+    logging.info(f"Loading LM validation data from {params.lm_data_valid}")
+    valid_dl = get_dataloader(
+        filename=params.lm_data_valid,
+        is_distributed=is_distributed,
+        params=params,
+    )
+
+    # Note: No learning rate scheduler is used here
+    for epoch in range(params.start_epoch, params.num_epochs):
+        if is_distributed:
+            train_dl.sampler.set_epoch(epoch)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            tb_writer=tb_writer,
+            world_size=world_size,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if is_distributed:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/icefall/utils.py b/icefall/utils.py
index c502cb4d8..4aa8197ad 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1,5 +1,6 @@
-# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang
-#                                                    Mingshuang Luo)
+# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                    Mingshuang Luo,
+#                                                    Zengwei Yao)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -130,9 +131,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -177,11 +176,13 @@ class AttributeDict(dict):
 
 
 def encode_supervisions(
-    supervisions: dict, subsampling_factor: int
-) -> Tuple[torch.Tensor, List[str]]:
+    supervisions: dict,
+    subsampling_factor: int,
+    token_ids: Optional[List[List[int]]] = None,
+) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]:
     """
     Encodes Lhotse's ``batch["supervisions"]`` dict into
-    a pair of torch Tensor, and a list of transcription strings.
+    a pair of torch Tensor, and a list of transcription strings or token indexes
 
     The supervision tensor has shape ``(batch_size, 3)``.
     Its second dimension contains information about sequence index [0],
@@ -194,18 +195,30 @@ def encode_supervisions(
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
-            supervisions["start_frame"] // subsampling_factor,
-            supervisions["num_frames"] // subsampling_factor,
+            torch.div(
+                supervisions["start_frame"],
+                subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                subsampling_factor,
+                rounding_mode="floor",
+            ),
         ),
         1,
     ).to(torch.int32)
 
     indices = torch.argsort(supervision_segments[:, 2], descending=True)
     supervision_segments = supervision_segments[indices]
-    texts = supervisions["text"]
-    texts = [texts[idx] for idx in indices]
 
-    return supervision_segments, texts
+    if token_ids is None:
+        texts = supervisions["text"]
+        res = [texts[idx] for idx in indices]
+    else:
+        res = [token_ids[idx] for idx in indices]
+
+    return supervision_segments, res
 
 
 def get_texts(
@@ -280,13 +293,9 @@ def get_texts_with_timestamp(
     """
     if isinstance(best_paths.aux_labels, k2.RaggedTensor):
         all_aux_shape = (
-            best_paths.arcs.shape()
-            .remove_axis(1)
-            .compose(best_paths.aux_labels.shape)
-        )
-        all_aux_labels = k2.RaggedTensor(
-            all_aux_shape, best_paths.aux_labels.values
+            best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape)
         )
+        all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values)
         # remove 0's and -1's.
         aux_labels = best_paths.aux_labels.remove_values_leq(0)
         # TODO: change arcs.shape() to arcs.shape
@@ -355,9 +364,7 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
     # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here
     token_shape = best_paths.arcs.shape().remove_axis(1)
     # token_shape has axes [fsa][arc]
-    tokens = k2.RaggedTensor(
-        token_shape, getattr(best_paths, kind).contiguous()
-    )
+    tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous())
     tokens = tokens.remove_values_eq(-1)
     return tokens.tolist()
 
@@ -447,11 +454,32 @@ def store_transcripts_and_timestamps(
         for cut_id, ref, hyp, time_ref, time_hyp in texts:
             print(f"{cut_id}:\tref={ref}", file=f)
             print(f"{cut_id}:\thyp={hyp}", file=f)
+
             if len(time_ref) > 0:
-                s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
+                if isinstance(time_ref[0], tuple):
+                    # each element is a (start, end) pair
+                    s = (
+                        "["
+                        + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_ref])
+                        + "]"
+                    )
+                else:
+                    # each element is a float number
+                    s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]"
                 print(f"{cut_id}:\ttimestamp_ref={s}", file=f)
-            s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
-            print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
+
+            if len(time_hyp) > 0:
+                if isinstance(time_hyp[0], tuple):
+                    # each element is a (start, end) pair
+                    s = (
+                        "["
+                        + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp])
+                        + "]"
+                    )
+                else:
+                    # each element is a float number
+                    s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]"
+                print(f"{cut_id}:\ttimestamp_hyp={s}", file=f)
 
 
 def write_error_stats(
@@ -578,9 +606,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -590,9 +616,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -606,9 +630,7 @@ def write_error_stats(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -624,9 +646,18 @@ def write_error_stats(
 def write_error_stats_with_timestamps(
     f: TextIO,
     test_set_name: str,
-    results: List[Tuple[str, List[str], List[str], List[float], List[float]]],
+    results: List[
+        Tuple[
+            str,
+            List[str],
+            List[str],
+            List[Union[float, Tuple[float, float]]],
+            List[Union[float, Tuple[float, float]]],
+        ]
+    ],
     enable_log: bool = True,
-) -> Tuple[float, float, float]:
+    with_end_time: bool = False,
+) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]:
     """Write statistics based on predicted results and reference transcripts
     as well as their timestamps.
 
@@ -659,6 +690,8 @@ def write_error_stats_with_timestamps(
       enable_log:
         If True, also print detailed WER to the console.
         Otherwise, it is written only to the given file.
+      with_end_time:
+        Whether to use end timestamps.
 
     Returns:
       Return total word error rate and mean delay.
@@ -676,8 +709,8 @@ def write_error_stats_with_timestamps(
     all_delay = []
     for cut_id, ref, hyp, time_ref, time_hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
-        has_time_ref = len(time_ref) > 0
-        if has_time_ref:
+        has_time = len(time_ref) > 0 and len(time_hyp) > 0
+        if has_time:
             # pointer to timestamp_hyp
             p_hyp = 0
             # pointer to timestamp_ref
@@ -686,28 +719,36 @@ def write_error_stats_with_timestamps(
             if ref_word == ERR:
                 ins[hyp_word] += 1
                 words[hyp_word][3] += 1
-                if has_time_ref:
+                if has_time:
                     p_hyp += 1
             elif hyp_word == ERR:
                 dels[ref_word] += 1
                 words[ref_word][4] += 1
-                if has_time_ref:
+                if has_time:
                     p_ref += 1
             elif hyp_word != ref_word:
                 subs[(ref_word, hyp_word)] += 1
                 words[ref_word][1] += 1
                 words[hyp_word][2] += 1
-                if has_time_ref:
+                if has_time:
                     p_hyp += 1
                     p_ref += 1
             else:
                 words[ref_word][0] += 1
                 num_corr += 1
-                if has_time_ref:
-                    all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
+                if has_time:
+                    if with_end_time:
+                        all_delay.append(
+                            (
+                                time_hyp[p_hyp][0] - time_ref[p_ref][0],
+                                time_hyp[p_hyp][1] - time_ref[p_ref][1],
+                            )
+                        )
+                    else:
+                        all_delay.append(time_hyp[p_hyp] - time_ref[p_ref])
                     p_hyp += 1
                     p_ref += 1
-        if has_time_ref:
+        if has_time:
             assert p_hyp == len(hyp), (p_hyp, len(hyp))
             assert p_ref == len(ref), (p_ref, len(ref))
 
@@ -716,16 +757,39 @@ def write_error_stats_with_timestamps(
     ins_errs = sum(ins.values())
     del_errs = sum(dels.values())
     tot_errs = sub_errs + ins_errs + del_errs
-    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+    tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len))
 
-    mean_delay = "inf"
-    var_delay = "inf"
+    if with_end_time:
+        mean_delay = (float("inf"), float("inf"))
+        var_delay = (float("inf"), float("inf"))
+    else:
+        mean_delay = float("inf")
+        var_delay = float("inf")
     num_delay = len(all_delay)
     if num_delay > 0:
-        mean_delay = sum(all_delay) / num_delay
-        var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
-        mean_delay = "%.3f" % mean_delay
-        var_delay = "%.3f" % var_delay
+        if with_end_time:
+            all_delay_start = [i[0] for i in all_delay]
+            mean_delay_start = sum(all_delay_start) / num_delay
+            var_delay_start = (
+                sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay
+            )
+
+            all_delay_end = [i[1] for i in all_delay]
+            mean_delay_end = sum(all_delay_end) / num_delay
+            var_delay_end = (
+                sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay
+            )
+
+            mean_delay = (
+                float("%.3f" % mean_delay_start),
+                float("%.3f" % mean_delay_end),
+            )
+            var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end))
+        else:
+            mean_delay = sum(all_delay) / num_delay
+            var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay
+            mean_delay = float("%.3f" % mean_delay)
+            var_delay = float("%.3f" % var_delay)
 
     if enable_log:
         logging.info(
@@ -734,7 +798,8 @@ def write_error_stats_with_timestamps(
             f"{del_errs} del, {sub_errs} sub ]"
         )
         logging.info(
-            f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} "  # noqa
+            f"[{test_set_name}] %symbol-delay mean (s): "
+            f"{mean_delay}, variance: {var_delay} "  # noqa
             f"computed on {num_delay} correct words"
         )
 
@@ -783,9 +848,7 @@ def write_error_stats_with_timestamps(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -795,9 +858,7 @@ def write_error_stats_with_timestamps(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
 
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count}   {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -811,9 +872,7 @@ def write_error_stats_with_timestamps(
         print(f"{count}   {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -823,7 +882,8 @@ def write_error_stats_with_timestamps(
         hyp_count = corr + hyp_sub + ins
 
         print(f"{word}   {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
-    return float(tot_err_rate), float(mean_delay), float(var_delay)
+
+    return tot_err_rate, mean_delay, var_delay
 
 
 class MetricsTracker(collections.defaultdict):
@@ -883,9 +943,7 @@ class MetricsTracker(collections.defaultdict):
             if k == "frames" or k == "utterances":
                 continue
             norm_value = (
-                float(v) / num_frames
-                if "utt_" not in k
-                else float(v) / num_utterances
+                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
             )
             ans.append((k, norm_value))
         return ans
@@ -919,9 +977,7 @@ class MetricsTracker(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
-def concat(
-    ragged: k2.RaggedTensor, value: int, direction: str
-) -> k2.RaggedTensor:
+def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
     """Prepend a value to the beginning of each sublist or append a value.
     to the end of each sublist.
 
@@ -1039,10 +1095,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     assert lengths.ndim == 1, lengths.ndim
     max_len = max(max_len, lengths.max())
     n = lengths.size(0)
+    seq_range = torch.arange(0, max_len, device=lengths.device)
+    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
 
-    expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
-
-    return expaned_lengths >= lengths.unsqueeze(1)
+    return expaned_lengths >= lengths.unsqueeze(-1)
 
 
 # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
@@ -1093,9 +1149,7 @@ def linf_norm(x):
     return torch.max(torch.abs(x))
 
 
-def measure_weight_norms(
-    model: nn.Module, norm: str = "l2"
-) -> Dict[str, float]:
+def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]:
     """
     Compute the norms of the model's parameters.
 
@@ -1118,9 +1172,7 @@ def measure_weight_norms(
         return norms
 
 
-def measure_gradient_norms(
-    model: nn.Module, norm: str = "l1"
-) -> Dict[str, float]:
+def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]:
     """
     Compute the norms of the gradients for each of model's parameters.
 
@@ -1254,6 +1306,31 @@ def tokenize_by_bpe_model(
     return txt_with_bpe
 
 
+def tokenize_by_CJK_char(line: str) -> str:
+    """
+    Tokenize a line of text with CJK char.
+
+    Note: All returned characters will be in upper case.
+
+    Example:
+      input = "你好世界是 hello world 的中文"
+      output = "你 好 世 界 是 HELLO WORLD 的 中 文"
+
+    Args:
+      line:
+        The input text.
+
+    Return:
+      A new string tokenized by CJK characters.
+    """
+    # The CJK ranges are from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
+    pattern = re.compile(
+        r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
+    )
+    chars = pattern.split(line.strip().upper())
+    return " ".join([w.strip() for w in chars if w.strip()])
+
+
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
@@ -1326,7 +1403,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
       List of timestamp of each word.
     """
     start_token = b"\xe2\x96\x81".decode()  # '_'
-    assert len(tokens) == len(timestamp)
+    assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
     ans = []
     for i in range(len(tokens)):
         flag = False
@@ -1347,10 +1424,9 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
 
 def parse_hyp_and_timestamp(
     res: DecodingResults,
-    decoding_method: str,
-    sp: spm.SentencePieceProcessor,
     subsampling_factor: int,
     frame_shift_ms: float = 10,
+    sp: Optional[spm.SentencePieceProcessor] = None,
     word_table: Optional[k2.SymbolTable] = None,
 ) -> Tuple[List[List[str]], List[List[float]]]:
     """Parse hypothesis and timestamp.
@@ -1358,56 +1434,32 @@ def parse_hyp_and_timestamp(
     Args:
       res:
         A DecodingResults object.
-      decoding_method:
-        Possible values are:
-          - greedy_search
-          - beam_search
-          - modified_beam_search
-          - fast_beam_search
-          - fast_beam_search_LG
-          - fast_beam_search_nbest
-          - fast_beam_search_nbest_oracle
-          - fast_beam_search_nbest_LG
-      sp:
-        The BPE model.
       subsampling_factor:
         The integer subsampling factor.
       frame_shift_ms:
         The float frame shift used for feature extraction.
+      sp:
+        The BPE model.
       word_table:
         The word symbol table.
 
     Returns:
        Return lists of hypotheses and timestamps.
     """
-    assert decoding_method in (
-        "greedy_search",
-        "beam_search",
-        "fast_beam_search",
-        "fast_beam_search_LG",
-        "fast_beam_search_nbest",
-        "fast_beam_search_nbest_LG",
-        "fast_beam_search_nbest_oracle",
-        "modified_beam_search",
-    )
-
     hyps = []
     timestamps = []
 
     N = len(res.hyps)
     assert len(res.timestamps) == N, (len(res.timestamps), N)
     use_word_table = False
-    if (
-        decoding_method == "fast_beam_search_nbest_LG"
-        and decoding_method == "fast_beam_search_LG"
-    ):
-        assert word_table is not None
+    if word_table is not None:
+        assert sp is None
         use_word_table = True
+    else:
+        assert sp is not None and word_table is None
 
     for i in range(N):
-        time = convert_timestamp(
-            res.timestamps[i], subsampling_factor, frame_shift_ms
-        )
+        time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms)
         if use_word_table:
             words = [word_table[i] for i in res.hyps[i]]
         else:
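The timestamp conversion called just above can be sketched in isolation. The arithmetic below assumes convert_timestamp multiplies each frame index by the subsampling factor and the frame shift (the function body is not shown in this hunk), so the printed values are an assumption:

from icefall.utils import convert_timestamp

print(convert_timestamp(frames=[0, 10, 25], subsampling_factor=4, frame_shift_ms=10))
# expected (under the assumption above): [0.0, 0.4, 1.0]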
@@ -1434,3 +1486,337 @@ def is_module_available(*modules: str) -> bool:
     import importlib
 
     return all(importlib.util.find_spec(m) is not None for m in modules)
+
+
+def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
+    """For the uneven-sized batch, the total duration after padding would possibly
+    cause OOM. Hence, for each batch, which is sorted descendingly by length,
+    we simply drop the last few shortest samples, so that the retained total frames
+    (after padding) would not exceed the given allow_max_frames.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      allowed_max_frames:
+        The maximum number of frames allowed in a batch.
+    """
+    features = batch["inputs"]
+    supervisions = batch["supervisions"]
+
+    N, T, _ = features.size()
+    assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
+    keep_num_utt = allowed_max_frames // T
+
+    if keep_num_utt >= N:
+        return batch
+
+    # Note: we assume that the samples in the batch are sorted by decreasing length
+    logging.info(
+        f"Filtering uneven-sized batch, original batch size is {N}, "
+        f"retained batch size is {keep_num_utt}."
+    )
+    batch["inputs"] = features[:keep_num_utt]
+    for k, v in supervisions.items():
+        assert len(v) == N, (len(v), N)
+        batch["supervisions"][k] = v[:keep_num_utt]
+
+    return batch
+
+
+def parse_bpe_start_end_pairs(
+    tokens: List[str], is_first_token: List[bool]
+) -> List[Tuple[int, int]]:
+    """Parse pairs of start and end frame indexes for each word.
+
+    Args:
+      tokens:
+        List of BPE tokens.
+      is_first_token:
+        List of bool values indicating whether each token is the first token,
+        i.e., neither a repeat nor a blank.
+
+    Returns:
+      List of (start-frame-index, end-frame-index) pairs for each word.
+    """
+    assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token))
+
+    start_token = b"\xe2\x96\x81".decode()  # '_'
+    blank_token = ""
+
+    non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token]
+    num_non_blank = len(non_blank_idx)
+
+    pairs = []
+    start = -1
+    end = -1
+    for j in range(num_non_blank):
+        # The index in all frames
+        i = non_blank_idx[j]
+
+        found_start = False
+        if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)):
+            found_start = True
+            if tokens[i] == start_token:
+                if j == num_non_blank - 1:
+                    # It is the last non-blank token
+                    found_start = False
+                elif is_first_token[non_blank_idx[j + 1]] and tokens[
+                    non_blank_idx[j + 1]
+                ].startswith(start_token):
+                    # The next non-blank token is a first-token and also starts with start_token
+                    found_start = False
+        if found_start:
+            start = i
+
+        if start != -1:
+            found_end = False
+            if j == num_non_blank - 1:
+                # It is the last non-blank token
+                found_end = True
+            elif is_first_token[non_blank_idx[j + 1]] and tokens[
+                non_blank_idx[j + 1]
+            ].startswith(start_token):
+                # The next non-blank token is a first-token and also starts with start_token
+                found_end = True
+            if found_end:
+                end = i
+
+        if start != -1 and end != -1:
+            if not all([tokens[t] == start_token for t in range(start, end + 1)]):
+                # except the case of all start_token
+                pairs.append((start, end))
+            # Reset start and end
+            start = -1
+            end = -1
+
+    return pairs
+
+
+def parse_bpe_timestamps_and_texts(
+    best_paths: k2.Fsa, sp: spm.SentencePieceProcessor
+) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
+    """Parse timestamps (frame indexes) and texts.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful). Its attributes `labels` and `aux_labels`
+        are both BPE tokens.
+      sp:
+        The BPE model.
+
+    Returns:
+      utt_index_pairs:
+        A list of pair lists. utt_index_pairs[i] is a list of
+        (start-frame-index, end-frame-index) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is the word list of utterance-i.
+    """
+    shape = best_paths.arcs.shape().remove_axis(1)
+
+    # labels: [utt][arcs]
+    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
+    # remove -1's.
+    labels = labels.remove_values_eq(-1)
+    labels = labels.tolist()
+
+    # aux_labels: [utt][arcs]
+    aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous())
+
+    # remove -1's.
+    all_aux_labels = aux_labels.remove_values_eq(-1)
+    # len(all_aux_labels[i]) is equal to the number of frames
+    all_aux_labels = all_aux_labels.tolist()
+
+    # remove 0's and -1's.
+    out_aux_labels = aux_labels.remove_values_leq(0)
+    # len(out_aux_labels[i]) is equal to the number of output BPE tokens
+    out_aux_labels = out_aux_labels.tolist()
+
+    utt_index_pairs = []
+    utt_words = []
+    for i in range(len(labels)):
+        tokens = sp.id_to_piece(labels[i])
+        words = sp.decode(out_aux_labels[i]).split()
+
+        # Indicates whether each token is the first one, i.e., not a repeat or a blank.
+        is_first_token = [a != 0 for a in all_aux_labels[i]]
+        index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
+        assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens)
+        utt_index_pairs.append(index_pairs)
+        utt_words.append(words)
+
+    return utt_index_pairs, utt_words
+
+
+def parse_timestamps_and_texts(
+    best_paths: k2.Fsa, word_table: k2.SymbolTable
+) -> Tuple[List[Tuple[int, int]], List[List[str]]]:
+    """Parse timestamps (frame indexes) and texts.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful). Attribute `labels` is the prediction unit,
+        e.g., phone or BPE tokens. Attribute `aux_labels` is the word index.
+      word_table:
+        The word symbol table.
+
+    Returns:
+      utt_index_pairs:
+        A list of pair lists. utt_index_pairs[i] is a list of
+        (start-frame-index, end-frame-index) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is the word list of utterance-i.
+    """
+    # [utt][words]
+    word_ids = get_texts(best_paths)
+
+    shape = best_paths.arcs.shape().remove_axis(1)
+
+    # labels: [utt][arcs]
+    labels = k2.RaggedTensor(shape, best_paths.labels.contiguous())
+    # remove -1's.
+    labels = labels.remove_values_eq(-1)
+    labels = labels.tolist()
+
+    # aux_labels: [utt][arcs]
+    aux_shape = shape.compose(best_paths.aux_labels.shape)
+    aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous())
+    aux_labels = aux_labels.tolist()
+
+    utt_index_pairs = []
+    utt_words = []
+    for i, (label, aux_label) in enumerate(zip(labels, aux_labels)):
+        num_arcs = len(label)
+        # The last arc of aux_label is the arc entering the final state
+        assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label))
+
+        index_pairs = []
+        start = -1
+        end = -1
+        for arc in range(num_arcs):
+            # len(aux_label[arc]) is 0 or 1
+            if label[arc] != 0 and len(aux_label[arc]) != 0:
+                if start != -1 and end != -1:
+                    index_pairs.append((start, end))
+                start = arc
+            if label[arc] != 0:
+                end = arc
+        if start != -1 and end != -1:
+            index_pairs.append((start, end))
+
+        words = [word_table[w] for w in word_ids[i]]
+        assert len(index_pairs) == len(words), (len(index_pairs), len(words))
+
+        utt_index_pairs.append(index_pairs)
+        utt_words.append(words)
+
+    return utt_index_pairs, utt_words
+
+
+def parse_fsa_timestamps_and_texts(
+    best_paths: k2.Fsa,
+    sp: Optional[spm.SentencePieceProcessor] = None,
+    word_table: Optional[k2.SymbolTable] = None,
+    subsampling_factor: int = 4,
+    frame_shift_ms: float = 10,
+) -> Tuple[List[Tuple[float, float]], List[List[str]]]:
+    """Parse timestamps (in seconds) and texts for given decoded fsa paths.
+    Currently it supports two cases:
+    (1) ctc-decoding, the attributes `labels` and `aux_labels`
+        are both BPE tokens. In this case, sp should be provided.
+    (2) HLG-based 1best, the attribute `labels` is the prediction unit,
+        e.g., phone or BPE tokens; attribute `aux_labels` is the word index.
+        In this case, word_table should be provided.
+
+    Args:
+      best_paths:
+        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
+        containing multiple FSAs, which is expected to be the result
+        of k2.shortest_path (otherwise the returned values won't
+        be meaningful).
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table.
+      subsampling_factor:
+        The subsampling factor of the model.
+      frame_shift_ms:
+        Frame shift in milliseconds between two contiguous frames.
+
+    Returns:
+      utt_time_pairs:
+        A list of pair lists. utt_time_pairs[i] is a list of
+        (start-time, end-time) pairs for each word in
+        utterance-i.
+      utt_words:
+        A list of str lists. utt_words[i] is the word list of utterance-i.
+    """
+    if sp is not None:
+        assert word_table is None, "word_table is not needed if sp is provided."
+        utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(
+            best_paths=best_paths, sp=sp
+        )
+    elif word_table is not None:
+        assert sp is None, "sp is not needed if word_table is provided."
+        utt_index_pairs, utt_words = parse_timestamps_and_texts(
+            best_paths=best_paths, word_table=word_table
+        )
+    else:
+        raise ValueError("Either sp or word_table should be provided.")
+
+    utt_time_pairs = []
+    for utt in utt_index_pairs:
+        start = convert_timestamp(
+            frames=[i[0] for i in utt],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        end = convert_timestamp(
+            # The duration in frames is (end_frame_index - start_frame_index + 1)
+            frames=[i[1] + 1 for i in utt],
+            subsampling_factor=subsampling_factor,
+            frame_shift_ms=frame_shift_ms,
+        )
+        utt_time_pairs.append(list(zip(start, end)))
+
+    return utt_time_pairs, utt_words
+
+
+# Copied from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
+def is_cjk(character):
+    """
+    Python port of Moses' code to check for CJK character.
+
+    >>> is_cjk(u'\u33fe')
+    True
+    >>> is_cjk(u'\uFE5F')
+    False
+
+    :param character: The character that needs to be checked.
+    :type character: char
+    :return: bool
+    """
+    return any(
+        [
+            start <= ord(character) <= end
+            for start, end in [
+                (4352, 4607),
+                (11904, 42191),
+                (43072, 43135),
+                (44032, 55215),
+                (63744, 64255),
+                (65072, 65103),
+                (65381, 65500),
+                (131072, 196607),
+            ]
+        ]
+    )
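A self-contained sketch of the new filter_uneven_sized_batch helper with synthetic values; the batch layout mirrors what lhotse's K2SpeechRecognitionDataset produces, and all numbers are illustrative:

import torch
from icefall.utils import filter_uneven_sized_batch

# 4 utterances padded to T=10 frames, 80-dim features, sorted by decreasing length.
batch = {
    "inputs": torch.randn(4, 10, 80),
    "supervisions": {
        "num_frames": torch.tensor([10, 8, 6, 4]),
        "text": ["a", "b", "c", "d"],
    },
}
# With at most 25 frames allowed, 25 // 10 = 2 utterances are kept.
batch = filter_uneven_sized_batch(batch, allowed_max_frames=25)
assert batch["inputs"].shape[0] == 2
assert len(batch["supervisions"]["text"]) == 2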
diff --git a/pyproject.toml b/pyproject.toml
index b4f8c3377..3183055d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ profile = "black"
 skip = ["icefall/__init__.py"]
 
 [tool.black]
-line-length = 80
+line-length = 88
 exclude = '''
 /(
     \.git
diff --git a/requirements-ci.txt b/requirements-ci.txt
index b8e49899e..50d4e5e3f 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -22,5 +22,6 @@ typeguard==2.13.3
 multi_quantization
 
 onnx
+onnxmltools
 onnxruntime
 kaldifst
diff --git a/requirements.txt b/requirements.txt
index 5e32af853..a07f6b7c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+kaldifst
 kaldilm
 kaldialign
 sentencepiece>=0.1.96
diff --git a/setup.py b/setup.py
index 6c720e121..ccd2503ff 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
 
-from setuptools import find_packages, setup
 from pathlib import Path
 
+from setuptools import find_packages, setup
+
 icefall_dir = Path(__file__).parent
 install_requires = (icefall_dir / "requirements.txt").read_text().splitlines()
 
diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py
index 511a11c23..34e829642 100644
--- a/test/test_checkpoint.py
+++ b/test/test_checkpoint.py
@@ -20,11 +20,7 @@ import pytest
 import torch
 import torch.nn as nn
 
-from icefall.checkpoint import (
-    average_checkpoints,
-    load_checkpoint,
-    save_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint
 
 
 @pytest.fixture
diff --git a/test/test_decode.py b/test/test_decode.py
index 97964ac67..4c2e192a7 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -23,6 +23,7 @@ You can run this file in one of the two ways:
 """
 
 import k2
+
 from icefall.decode import Nbest
 
 
diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py
index ccfb57d49..10443cf22 100644
--- a/test/test_graph_compiler.py
+++ b/test/test_graph_compiler.py
@@ -154,9 +154,7 @@ class TestCtcTrainingGraphCompiler(object):
         fsas = k2.Fsa.from_fsas([fsa1, fsa2])
 
         decoding_graph = k2.arc_sort(decoding_graph)
-        lattice = k2.intersect(
-            decoding_graph, fsas, treat_epsilons_specially=False
-        )
+        lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False)
         lattice = k2.connect(lattice)
 
         aux_labels0 = lattice[0].aux_labels[:-1]
diff --git a/test/test_lexicon.py b/test/test_lexicon.py
index 69867efc7..b1beab3f6 100755
--- a/test/test_lexicon.py
+++ b/test/test_lexicon.py
@@ -112,7 +112,7 @@ def uniq_lexicon_test():
     # But there is no word "ca" in the lexicon, so our
     # implementation returns the id of ""
     print(token_ids, expected_token_ids)
-    assert token_ids.tolist() == [[sp.unk_id()]]
+    assert token_ids.tolist() == [[sp.piece_to_id("▁"), sp.unk_id()]]
 
     # case 3: With OOV
     texts = ["foo"]
diff --git a/test/test_parse_timestamp.py b/test/test_parse_timestamp.py
new file mode 100755
index 000000000..92bfb49c6
--- /dev/null
+++ b/test/test_parse_timestamp.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# Copyright      2023  Xiaomi Corp.        (authors: Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+
+import k2
+import sentencepiece as spm
+import torch
+
+from icefall.lexicon import Lexicon
+from icefall.utils import parse_bpe_timestamps_and_texts, parse_timestamps_and_texts
+
+ICEFALL_DIR = Path(__file__).resolve().parent.parent
+
+
+def test_parse_bpe_timestamps_and_texts():
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500"
+    if not lang_dir.is_dir():
+        print(f"{lang_dir} does not exist.")
+        return
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(str(lang_dir / "bpe.model"))
+
+    text_1 = "HELLO WORLD"
+    token_ids_1 = sp.encode(text_1, out_type=int)
+    # out_type=str: ['_HE', 'LL', 'O', '_WORLD']
+    # out_type=int: [22, 58, 24, 425]
+
+    # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0]
+    labels_1 = (
+        token_ids_1[0:1] * 2
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4] * 3
+        + [0] * 2
+    )
+    # [22, 0, 58, 24, 0, 0, 425, 0, 0, 0, 0, -1]
+    aux_labels_1 = (
+        token_ids_1[0:1]
+        + [0]
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4]
+        + [0] * 4
+        + [-1]
+    )
+    fsa_1 = k2.linear_fsa(labels_1)
+    fsa_1.aux_labels = torch.tensor(aux_labels_1).to(torch.int32)
+
+    text_2 = "SAY GOODBYE"
+    token_ids_2 = sp.encode(text_2, out_type=int)
+    # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E']
+    # out_type=int: [289, 286, 41, 16, 11]
+
+    # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0]
+    labels_2 = (
+        token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2
+    )
+    # [289, 0, 0, 286, 0, 41, 16, 11, 0, 0, -1]
+    aux_labels_2 = (
+        token_ids_2[0:1]
+        + [0] * 2
+        + token_ids_2[1:2]
+        + [0]
+        + token_ids_2[2:5]
+        + [0] * 2
+        + [-1]
+    )
+    fsa_2 = k2.linear_fsa(labels_2)
+    fsa_2.aux_labels = torch.tensor(aux_labels_2).to(torch.int32)
+
+    fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2])
+
+    utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(fsa_vec, sp)
+    assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0]
+    assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0]
+    assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1]
+    assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1]
+
+
+def test_parse_timestamps_and_texts():
+    lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500"
+    if not lang_dir.is_dir():
+        print(f"{lang_dir} does not exist.")
+        return
+
+    lexicon = Lexicon(lang_dir)
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(str(lang_dir / "bpe.model"))
+    word_table = lexicon.word_table
+
+    text_1 = "HELLO WORLD"
+    token_ids_1 = sp.encode(text_1, out_type=int)
+    # out_type=str: ['_HE', 'LL', 'O', '_WORLD']
+    # out_type=int: [22, 58, 24, 425]
+    word_ids_1 = [word_table[s] for s in text_1.split()]  # [79677, 196937]
+    # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0]
+    labels_1 = (
+        token_ids_1[0:1] * 2
+        + token_ids_1[1:3]
+        + [0] * 2
+        + token_ids_1[3:4] * 3
+        + [0] * 2
+    )
+    # [[79677], [], [], [], [], [], [196937], [], [], [], [], []]
+    aux_labels_1 = [word_ids_1[0:1]] + [[]] * 5 + [word_ids_1[1:2]] + [[]] * 5
+
+    fsa_1 = k2.linear_fsa(labels_1)
+    fsa_1.aux_labels = k2.RaggedTensor(aux_labels_1)
+
+    text_2 = "SAY GOODBYE"
+    token_ids_2 = sp.encode(text_2, out_type=int)
+    # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E']
+    # out_type=int: [289, 286, 41, 16, 11]
+    word_ids_2 = [word_table[s] for s in text_2.split()]  # [154967, 72079]
+    # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0]
+    labels_2 = (
+        token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2
+    )
+    # [[154967], [], [], [72079], [], [], [], [], [], [], []]
+    aux_labels_2 = [word_ids_2[0:1]] + [[]] * 2 + [word_ids_2[1:2]] + [[]] * 7
+
+    fsa_2 = k2.linear_fsa(labels_2)
+    fsa_2.aux_labels = k2.RaggedTensor(aux_labels_2)
+
+    fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2])
+
+    utt_index_pairs, utt_words = parse_timestamps_and_texts(fsa_vec, word_table)
+    assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0]
+    assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0]
+    assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1]
+    assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1]
+
+
+if __name__ == "__main__":
+    test_parse_bpe_timestamps_and_texts()
+    test_parse_timestamps_and_texts()
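The frame-index pairing that these tests exercise indirectly comes from parse_bpe_start_end_pairs; a minimal stand-alone sketch with made-up frame-level tokens ("" marks a blank, "▁" starts a word):

from icefall.utils import parse_bpe_start_end_pairs

tokens = ["▁HE", "▁HE", "LL", "O", "", "", "▁WORLD", "▁WORLD", ""]
is_first_token = [True, False, True, True, False, False, True, False, False]
print(parse_bpe_start_end_pairs(tokens, is_first_token))
# expected: [(0, 3), (6, 7)]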
diff --git a/test/test_utils.py b/test/test_utils.py
index 6a9ce7853..31f06bd51 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -50,9 +50,7 @@ def test_encode_supervisions(sup):
     assert torch.all(
         torch.eq(
             supervision_segments,
-            torch.tensor(
-                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
-            ),
+            torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]),
         )
     )
     assert texts == ["two", "one", "three"]