diff --git a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
index bb7c7dfdc..0bec8c0c4 100755
--- a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+++ b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
@@ -15,5 +15,5 @@ mkdir -p data
 cd data
 [ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank
 cd ..
-./local/compute_fbank_librispeech.py
+./local/compute_fbank_librispeech.py --dataset 'test-clean test-other'
 ls -lh data/fbank/
diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
index e70a1848d..4c393f6be 100755
--- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
+++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
@@ -25,7 +25,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
index df29f188e..c68ccc954 100755
--- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh b/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
deleted file mode 100755
index 32c939206..000000000
--- a/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh
+++ /dev/null
@@ -1,79 +0,0 @@
-#!/usr/bin/env bash
-#
-set -e
-
-log() {
-  # This function is from espnet
-  local fname=${BASH_SOURCE[1]##*/}
-  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
-
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-pushd $repo
-git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-cd exp
-ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
-popd
-
-log "Display test files"
-tree $repo/
-soxi $repo/test_wavs/*.wav
-ls -lh $repo/test_wavs/*.wav
-
-log "Install ncnn and pnnx"
-
-# We are using a modified ncnn here. Will try to merge it to the official repo
-# of ncnn
-git clone https://github.com/csukuangfj/ncnn
-pushd ncnn
-git submodule init
-git submodule update python/pybind11
-python3 setup.py bdist_wheel
-ls -lh dist/
-pip install dist/*.whl
-cd tools/pnnx
-mkdir build
-cd build
-cmake -D Python3_EXECUTABLE=/opt/hostedtoolcache/Python/3.8.14/x64/bin/python3 ..
-make -j4 pnnx
-
-./src/pnnx || echo "pass"
-
-popd
-
-log "Test exporting to pnnx format"
-
-./conv_emformer_transducer_stateless2/export-for-ncnn.py \
-  --exp-dir $repo/exp \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
-  --epoch 99 \
-  --avg 1 \
-  --use-averaged-model 0 \
-  \
-  --num-encoder-layers 12 \
-  --chunk-length 32 \
-  --cnn-module-kernel 31 \
-  --left-context-length 32 \
-  --right-context-length 8 \
-  --memory-size 32
-
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
-
-./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
-  --tokens $repo/data/lang_bpe_500/tokens.txt \
-  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
-  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
-  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
-  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
-  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
-  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
-  $repo/test_wavs/1089-134686-0001.wav
diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
index 9b883f889..4cd2c4bec 100755
--- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
+++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
@@ -20,7 +20,6 @@ abs_repo=$(realpath $repo)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
@@ -28,63 +27,6 @@ ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
 ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
 popd
 
-log "Install ncnn and pnnx"
-
-# We are using a modified ncnn here. Will try to merge it to the official repo
-# of ncnn
-git clone https://github.com/csukuangfj/ncnn
-pushd ncnn
-git submodule init
-git submodule update python/pybind11
-python3 setup.py bdist_wheel
-ls -lh dist/
-pip install dist/*.whl
-cd tools/pnnx
-mkdir build
-cd build
-cmake ..
-make -j4 pnnx
-
-./src/pnnx || echo "pass"
-
-popd
-
-log "Test exporting to pnnx format"
-
-./lstm_transducer_stateless2/export.py \
-  --exp-dir $repo/exp \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
-  --epoch 99 \
-  --avg 1 \
-  --use-averaged-model 0 \
-  --pnnx 1
-
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
-
-./lstm_transducer_stateless2/ncnn-decode.py \
-  --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
-  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
-  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
-  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
-  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
-  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
-  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
-  $repo/test_wavs/1089-134686-0001.wav
-
-./lstm_transducer_stateless2/streaming-ncnn-decode.py \
-  --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
-  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
-  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
-  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
-  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
-  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
-  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
-  $repo/test_wavs/1089-134686-0001.wav
-
-
 log "Test exporting with torch.jit.trace()"
 
 ./lstm_transducer_stateless2/export.py \
@@ -106,47 +48,6 @@ log "Decode with models exported by torch.jit.trace()"
   $repo/test_wavs/1221-135766-0001.wav \
   $repo/test_wavs/1221-135766-0002.wav
 
-log "Test exporting to ONNX"
-
-./lstm_transducer_stateless2/export.py \
-  --exp-dir $repo/exp \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
-  --epoch 99 \
-  --avg 1 \
-  --use-averaged-model 0 \
-  --onnx 1
-
-log "Decode with ONNX models "
-
-./lstm_transducer_stateless2/streaming-onnx-decode.py \
-  --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
-  --encoder-model-filename $repo//exp/encoder.onnx \
-  --decoder-model-filename $repo/exp/decoder.onnx \
-  --joiner-model-filename $repo/exp/joiner.onnx \
-  --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
-  --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
-  $repo/test_wavs/1089-134686-0001.wav
-
-./lstm_transducer_stateless2/streaming-onnx-decode.py \
-  --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
-  --encoder-model-filename $repo//exp/encoder.onnx \
-  --decoder-model-filename $repo/exp/decoder.onnx \
-  --joiner-model-filename $repo/exp/joiner.onnx \
-  --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
-  --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
-  $repo/test_wavs/1221-135766-0001.wav
-
-./lstm_transducer_stateless2/streaming-onnx-decode.py \
-  --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
-  --encoder-model-filename $repo//exp/encoder.onnx \
-  --decoder-model-filename $repo/exp/decoder.onnx \
-  --joiner-model-filename $repo/exp/joiner.onnx \
-  --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
-  --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
-  $repo/test_wavs/1221-135766-0002.wav
-
-
 for sym in 1 2 3; do
   log "Greedy search with --max-sym-per-frame $sym"
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
index dafea56db..6792c7088 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
index c3d07dc0e..dbf678d72 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
@@ -23,7 +23,6 @@ popd
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
index 22de3b45d..b6d477afe 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
@@ -22,7 +22,6 @@ popd
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
index 880767443..efa4b53f0 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
@@ -27,14 +26,6 @@ ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
 ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
 popd
 
-log "Test exporting to ONNX format"
-
-./pruned_transducer_stateless3/export.py \
-  --exp-dir $repo/exp \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
-  --epoch 99 \
-  --avg 1 \
-  --onnx 1
 
 log "Export to torchscript model"
 ./pruned_transducer_stateless3/export.py \
@@ -51,30 +42,8 @@ log "Export to torchscript model"
   --avg 1 \
   --jit-trace 1
 
-ls -lh $repo/exp/*.onnx
 ls -lh $repo/exp/*.pt
 
-log "Decode with ONNX models"
-
-./pruned_transducer_stateless3/onnx_check.py \
-  --jit-filename $repo/exp/cpu_jit.pt \
-  --onnx-encoder-filename $repo/exp/encoder.onnx \
-  --onnx-decoder-filename $repo/exp/decoder.onnx \
-  --onnx-joiner-filename $repo/exp/joiner.onnx \
-  --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
-  --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
-
-./pruned_transducer_stateless3/onnx_pretrained.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
-  --encoder-model-filename $repo/exp/encoder.onnx \
-  --decoder-model-filename $repo/exp/decoder.onnx \
-  --joiner-model-filename $repo/exp/joiner.onnx \
-  --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
-  --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
-  $repo/test_wavs/1089-134686-0001.wav \
-  $repo/test_wavs/1221-135766-0001.wav \
-  $repo/test_wavs/1221-135766-0002.wav
-
 log "Decode with models exported by torch.jit.trace()"
 
 ./pruned_transducer_stateless3/jit_pretrained.py \
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
index c6a781318..511fe0c9e 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
index 999841b80..2bc179c86 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
@@ -30,15 +29,6 @@ ln -s pretrained.pt epoch-99.pt
 ls -lh *.pt
 popd
 
-log "Test exporting to ONNX format"
-./pruned_transducer_stateless7/export.py \
-  --exp-dir $repo/exp \
-  --use-averaged-model false \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
-  --epoch 99 \
-  --avg 1 \
-  --onnx 1
-
 log "Export to torchscript model"
 ./pruned_transducer_stateless7/export.py \
   --exp-dir $repo/exp \
@@ -50,27 +40,6 @@ log "Export to torchscript model"
 
 ls -lh $repo/exp/*.pt
 
-log "Decode with ONNX models"
-
-./pruned_transducer_stateless7/onnx_check.py \
-  --jit-filename $repo/exp/cpu_jit.pt \
-  --onnx-encoder-filename $repo/exp/encoder.onnx \
-  --onnx-decoder-filename $repo/exp/decoder.onnx \
-  --onnx-joiner-filename $repo/exp/joiner.onnx \
-  --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
-  --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
-
-./pruned_transducer_stateless7/onnx_pretrained.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
-  --encoder-model-filename $repo/exp/encoder.onnx \
-  --decoder-model-filename $repo/exp/decoder.onnx \
-  --joiner-model-filename $repo/exp/joiner.onnx \
-  --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
-  --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
-  $repo/test_wavs/1089-134686-0001.wav \
-  $repo/test_wavs/1221-135766-0001.wav \
-  $repo/test_wavs/1221-135766-0002.wav
-
 log "Decode with models exported by torch.jit.script()"
 
 ./pruned_transducer_stateless7/jit_pretrained.py \
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
index 3cbb480f6..192438353 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
@@ -148,4 +147,4 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
   done
 
   rm pruned_transducer_stateless7_ctc/exp/*.pt
-fi
\ No newline at end of file
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
index ed66a728e..761eb72e2 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
@@ -10,7 +10,7 @@ log() {
 
 cd egs/librispeech/ASR
 
-repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2022-12-14
+repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
 
 log "Downloading pre-trained model from $repo_url"
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
index afb0dc05a..e1e4e1f10 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
@@ -19,16 +19,16 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
-pushd $repo/exp
+pushd $repo
 git lfs pull --include "data/lang_bpe_500/bpe.model"
 git lfs pull --include "exp/cpu_jit.pt"
 git lfs pull --include "exp/pretrained.pt"
 git lfs pull --include "exp/encoder_jit_trace.pt"
 git lfs pull --include "exp/decoder_jit_trace.pt"
 git lfs pull --include "exp/joiner_jit_trace.pt"
+cd exp
 ln -s pretrained.pt epoch-99.pt
 ls -lh *.pt
 popd
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
index e782b8425..5d9485692 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
index af37102d5..77cd59506 100755
--- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
+++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
index 5b8ed396b..b4aca1b6b 100755
--- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
+++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
index 77f28b054..a58b8ec56 100755
--- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
+++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh
index 96c320616..125d1f3b1 100755
--- a/.github/scripts/run-pre-trained-conformer-ctc.sh
+++ b/.github/scripts/run-pre-trained-conformer-ctc.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.flac
 ls -lh $repo/test_wavs/*.flac
 
 log "CTC decoding"
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
index 209d4814f..89115e88d 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
index 34ff76fe4..85e2c89e6 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
index 75650c2d3..0644d9be0 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
index bcc2d74cb..79fb64311 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh
index d3e40315a..41456f11b 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh
index cfa006776..1331c966c 100755
--- a/.github/scripts/run-pre-trained-transducer.sh
+++ b/.github/scripts/run-pre-trained-transducer.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 log "Beam search decoding"
diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
index 2d237dcf2..90097c752 100755
--- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
+++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
@@ -20,7 +20,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh
new file mode 100755
index 000000000..52491d2ea
--- /dev/null
+++ b/.github/scripts/test-ncnn-export.sh
@@ -0,0 +1,234 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+pushd egs/librispeech/ASR
+
+log "Install ncnn and pnnx"
+
+# We are using a modified ncnn here. Will try to merge it to the official repo
+# of ncnn
+git clone https://github.com/csukuangfj/ncnn
+pushd ncnn
+git submodule init
+git submodule update python/pybind11
+python3 setup.py bdist_wheel
+ls -lh dist/
+pip install dist/*.whl
+cd tools/pnnx
+mkdir build
+cd build
+
+echo "which python3"
+
+which python3
+#/opt/hostedtoolcache/Python/3.8.16/x64/bin/python3
+
+cmake -D Python3_EXECUTABLE=$(which python3) ..
+make -j4 pnnx
+
+./src/pnnx || echo "pass"
+
+popd
+
+export PATH=$PWD/ncnn/tools/pnnx/build/src:$PATH
+
+log "=========================================================================="
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
+
+cd exp
+ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.trace()"
+
+./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+  --exp-dir $repo/exp \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0 \
+  \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+
+pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+python3 ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/test_wavs/1089-134686-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
+
+cd exp
+ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.trace()"
+
+./lstm_transducer_stateless2/export-for-ncnn.py \
+  --exp-dir $repo/exp \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0
+
+pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/test_wavs/1089-134686-0001.wav
+
+python3 ./lstm_transducer_stateless2/ncnn-decode.py \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/test_wavs/1089-134686-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --exp-dir $repo/exp \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  \
+  --decode-chunk-len 32 \
+  --num-encoder-layers "2,4,3,2,4" \
+  --feedforward-dims "1024,1024,2048,2048,1024" \
+  --nhead "8,8,8,8,8" \
+  --encoder-dims "384,384,384,384,384" \
+  --attention-dims "192,192,192,192,192" \
+  --encoder-unmasked-dims "256,256,256,256,256" \
+  --zipformer-downsampling-factors "1,2,4,8,2" \
+  --cnn-module-kernels "31,31,31,31,31" \
+  --decoder-dim 512 \
+  --joiner-dim 512
+
+pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/test_wavs/1089-134686-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_char_bpe/L.pt"
+git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
+git lfs pull --include "data/lang_char_bpe/Linv.pt"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
+  --lang-dir $repo/data/lang_char_bpe \
+  --exp-dir $repo/exp \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --decode-chunk-len 32 \
+  --num-encoder-layers "2,4,3,2,4" \
+  --feedforward-dims "1024,1024,1536,1536,1024" \
+  --nhead "8,8,8,8,8" \
+  --encoder-dims "384,384,384,384,384" \
+  --attention-dims "192,192,192,192,192" \
+  --encoder-unmasked-dims "256,256,256,256,256" \
+  --zipformer-downsampling-factors "1,2,4,8,2" \
+  --cnn-module-kernels "31,31,31,31,31" \
+  --decoder-dim 512 \
+  --joiner-dim 512
+
+pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
+  --tokens $repo/data/lang_char_bpe/tokens.txt \
+  --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+  --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+  --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+  --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+  --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+  --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/test_wavs/0.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh
new file mode 100755
index 000000000..39467c44a
--- /dev/null
+++ b/.github/scripts/test-onnx-export.sh
@@ -0,0 +1,351 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+
+
+log "=========================================================================="
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.trace()"
+
+./pruned_transducer_stateless7_streaming/jit_trace_export.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --decode-chunk-len 32 \
+  --exp-dir $repo/exp/
+
+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless7_streaming/export-onnx.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --decode-chunk-len 32 \
+  --exp-dir $repo/exp/
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./pruned_transducer_stateless7_streaming/onnx_check.py \
+  --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
+  --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
+  --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
+  --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  $repo/test_wavs/1089-134686-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt"
+
+cd exp
+ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
+popd
+
+log "Export via torch.jit.script()"
+
+./pruned_transducer_stateless3/export.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 9999 \
+  --avg 1 \
+  --exp-dir $repo/exp/ \
+  --jit 1
+
+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless3/export-onnx.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 9999 \
+  --avg 1 \
+  --exp-dir $repo/exp/
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./pruned_transducer_stateless3/onnx_check.py \
+  --jit-filename $repo/exp/cpu_jit.pt \
+  --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
+  --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
+  --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./pruned_transducer_stateless3/onnx_pretrained.py \
+  --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
+  --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
+  --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  $repo/test_wavs/1089-134686-0001.wav \
+  $repo/test_wavs/1221-135766-0001.wav \
+  $repo/test_wavs/1221-135766-0002.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-epoch-39-avg-7.pt"
+
+cd exp
+ln -s pretrained-epoch-39-avg-7.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.script()"
+
+./pruned_transducer_stateless5/export.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0 \
+  --exp-dir $repo/exp \
+  --num-encoder-layers 18 \
+  --dim-feedforward 2048 \
+  --nhead 8 \
+  --encoder-dim 512 \
+  --decoder-dim 512 \
+  --joiner-dim 512 \
+  --jit 1
+
+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless5/export-onnx.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0 \
+  --exp-dir $repo/exp \
+  --num-encoder-layers 18 \
+  --dim-feedforward 2048 \
+  --nhead 8 \
+  --encoder-dim 512 \
+  --decoder-dim 512 \
+  --joiner-dim 512
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./pruned_transducer_stateless5/onnx_check.py \
+  --jit-filename $repo/exp/cpu_jit.pt \
+  --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./pruned_transducer_stateless5/onnx_pretrained.py \
+  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  $repo/test_wavs/1089-134686-0001.wav \
+  $repo/test_wavs/1221-135766-0001.wav \
+  $repo/test_wavs/1221-135766-0002.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.script()"
+
+./pruned_transducer_stateless7/export.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --exp-dir $repo/exp \
+  --feedforward-dims "1024,1024,2048,2048,1024" \
+  --jit 1
+
+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless7/export-onnx.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --exp-dir $repo/exp \
+  --feedforward-dims "1024,1024,2048,2048,1024"
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./pruned_transducer_stateless7/onnx_check.py \
+  --jit-filename $repo/exp/cpu_jit.pt \
+  --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./pruned_transducer_stateless7/onnx_pretrained.py \
+  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  $repo/test_wavs/1089-134686-0001.wav \
+  $repo/test_wavs/1221-135766-0001.wav \
+  $repo/test_wavs/1221-135766-0002.wav
+
+log "=========================================================================="
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
+
+cd exp
+ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
+popd
+
+log "Test exporting to ONNX format"
+
+./conv_emformer_transducer_stateless2/export-onnx.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --exp-dir $repo/exp \
+  --num-encoder-layers 12 \
+  --chunk-length 32 \
+  --cnn-module-kernel 31 \
+  --left-context-length 32 \
+  --right-context-length 8 \
+  --memory-size 32
+
+log "Run onnx_pretrained.py"
+
+./conv_emformer_transducer_stateless2/onnx_pretrained.py \
+  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  $repo/test_wavs/1221-135766-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
+
+cd exp
+ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.trace()"
+
+./lstm_transducer_stateless2/export.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --exp-dir $repo/exp/ \
+  --jit-trace 1
+
+log "Test exporting to ONNX format"
+
+./lstm_transducer_stateless2/export-onnx.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --exp-dir $repo/exp
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./lstm_transducer_stateless2/onnx_check.py \
+  --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
+  --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
+  --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
+  --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./lstm_transducer_stateless2/onnx_pretrained.py \
+  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  $repo/test_wavs/1221-135766-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml
index 1865a0da8..f5ba73195 100644
--- a/.github/workflows/run-aishell-2022-06-20.yml
+++ b/.github/workflows/run-aishell-2022-06-20.yml
@@ -65,7 +65,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -87,7 +87,7 @@ jobs:
           GITHUB_EVENT_NAME: ${{ github.event_name }}
           GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
         run: |
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml
index e438c5dba..c7b9cc79d 100644
--- a/.github/workflows/run-gigaspeech-2022-05-13.yml
+++ b/.github/workflows/run-gigaspeech-2022-05-13.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml
index 3ba6850cd..9c7cd1228 100644
--- a/.github/workflows/run-librispeech-2022-03-12.yml
+++ b/.github/workflows/run-librispeech-2022-03-12.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml
index 595b410b8..78c9e759f 100644
--- a/.github/workflows/run-librispeech-2022-04-29.yml
+++ b/.github/workflows/run-librispeech-2022-04-29.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml
index eb0b06a2d..04799bf52 100644
--- a/.github/workflows/run-librispeech-2022-05-13.yml
+++ b/.github/workflows/run-librispeech-2022-05-13.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
index 7694e8bf5..6dfc23920 100644
--- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
+++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
@@ -39,7 +39,7 @@ concurrency:
 
 jobs:
   run_librispeech_2022_11_11_zipformer:
-    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
index acb11a8f4..0544e68b3 100644
--- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
+++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
index ccd8d50d0..62e1f2a01 100644
--- a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
+++ b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
@@ -60,7 +60,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -119,7 +119,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
index 5472ca59b..7dc33aaa9 100644
--- a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
+++ b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml
index 6e2b40cf3..de55847ad 100644
--- a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml
+++ b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml
@@ -35,7 +35,7 @@ on:
 
 jobs:
   run_librispeech_2022_12_15_zipformer_ctc_bs:
-    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -60,7 +60,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -119,7 +119,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml
index 6dd93946a..feb5c6fd0 100644
--- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml
+++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
index d763fb1c5..c95ed8b9a 100644
--- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
+++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index 3752f67e3..e14d4e92f 100644
--- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -22,7 +22,7 @@ concurrency:
 
 jobs:
   run_librispeech_lstm_transducer_stateless2_2022_09_03:
-    if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -47,7 +47,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -106,7 +106,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
index 2c2bcab0c..73d91fcd4 100644
--- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
+++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
@@ -39,7 +39,7 @@ concurrency:
 
 jobs:
   run_librispeech_pruned_transducer_stateless3_2022_05_13:
-    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
           export PYTHONPATH=$PWD:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
           export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
index ac7e58b20..8a690393e 100644
--- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
+++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
@@ -64,7 +64,7 @@ jobs:
         run: |
           grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
           pip uninstall -y protobuf
-          pip install --no-binary protobuf protobuf
+          pip install --no-binary protobuf protobuf==3.20.*
 
       - name: Cache kaldifeat
         id: my-cache
@@ -123,7 +123,7 @@ jobs:
           ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
           ls -lh egs/librispeech/ASR/data/*
 
-          sudo apt-get -qq install git-lfs tree sox
+          sudo apt-get -qq install git-lfs tree
apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml index 575727e22..217dbdfa1 100644 --- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index 7dbfd2bd9..4e8e7b8db 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml index d6b3de8d4..ddde4f1d6 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml index 749fb3fca..00ea97b2a 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary 
protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 92bf6feb8..b3cfc9efd 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml index e51da8bd8..ab598541d 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 2103d0510..d663d49dd 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml index 902319b55..9cb9d3b59 100644 --- a/.github/workflows/run-pretrained-transducer.yml +++ b/.github/workflows/run-pretrained-transducer.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install 
pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml index 47ed958f2..f8d9c02c5 100644 --- a/.github/workflows/run-ptb-rnn-lm.yml +++ b/.github/workflows/run-ptb-rnn-lm.yml @@ -47,7 +47,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Prepare data shell: bash diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml index 8a7be0b80..14fb96ec8 100644 --- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml +++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -76,7 +76,7 @@ jobs: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index ed343aee5..1187dbf38 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -67,7 +67,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Run yesno recipe shell: bash diff --git a/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml b/.github/workflows/test-ncnn-export.yml similarity index 73% rename from .github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml rename to .github/workflows/test-ncnn-export.yml index b9a1582c4..cdea54854 100644 --- a/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml +++ b/.github/workflows/test-ncnn-export.yml @@ -1,4 +1,4 @@ -name: run-librispeech-conv-emformer-transducer-stateless2-2022-12-05 +name: test-ncnn-export on: push: @@ -16,15 +16,18 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: test_ncnn_export-${{ github.ref }} + cancel-in-progress: true + jobs: - run_librispeech_conv_emformer_transducer_stateless2_2022_12_05: + test_ncnn_export: if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest] python-version: [3.8] - fail-fast: false steps: @@ -41,9 +44,9 @@ jobs: - name: 
Install Python dependencies run: | - grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -59,19 +62,14 @@ jobs: run: | .github/scripts/install-kaldifeat.sh - - name: Inference with pre-trained model + - name: Test ncnn export shell: bash env: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree sox export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh + .github/scripts/test-ncnn-export.sh diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml new file mode 100644 index 000000000..3dc4261ab --- /dev/null +++ b/.github/workflows/test-onnx-export.yml @@ -0,0 +1,75 @@ +name: test-onnx-export + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: test_onnx_export-${{ github.ref }} + cancel-in-progress: true + +jobs: + test_onnx_export: + if: github.event.label.name == 'ready' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Test ONNX export + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/test-onnx-export.sh diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c062a2a3d..0da4f6b4b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -56,7 +56,7 @@ jobs: run: | sudo apt update sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg - sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all + sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all - name: Install Python dependencies run: | @@ -70,7 +70,7 @@ jobs: pip 
install git+https://github.com/lhotse-speech/lhotse # icefall requirements pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* pip install kaldifst pip install onnxruntime diff --git a/LICENSE b/LICENSE index ee06cfc77..d64569567 100644 --- a/LICENSE +++ b/LICENSE @@ -1,13 +1,4 @@ - Legal Notices - - NOTE (this is not from the Apache License): The copyright model is that - authors (or their employers, if noted in individual files) own their - individual contributions. The authors' contributions can be discerned - from the git history. - - ------------------------------------------------------------------------- - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ diff --git a/docs/source/conf.py b/docs/source/conf.py index ef9fe1445..6901dec02 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -81,9 +81,12 @@ todo_include_todos = True rst_epilog = """ .. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn +.. _sherpa-onnx: https://github.com/k2-fsa/sherpa-onnx .. _icefall: https://github.com/k2-fsa/icefall .. _git-lfs: https://git-lfs.com/ .. _ncnn: https://github.com/tencent/ncnn .. _LibriSpeech: https://www.openslr.org/12 .. _musan: http://www.openslr.org/17/ +.. _ONNX: https://github.com/onnx/onnx +.. _onnxruntime: https://github.com/microsoft/onnxruntime """ diff --git a/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..fe4460985 --- /dev/null +++ b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt @@ -0,0 +1,18 @@ +2023-02-17 11:22:42,862 INFO [export-for-ncnn.py:222] device: cpu +2023-02-17 11:22:42,865 INFO [export-for-ncnn.py:231] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'dim_feedforward': 2048, 'decoder_dim': 512, 'joiner_dim': 512, 'is_pnnx': False, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-dirty', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp'), 'bpe_model': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': 12, 'encoder_dim': 512, 'rnn_hidden_size': 1024, 'aux_layer_period': 0, 'blank_id': 0, 'vocab_size': 500} +2023-02-17 11:22:42,865 INFO [export-for-ncnn.py:235] About to create model +2023-02-17 11:22:43,239 INFO [train.py:472] Disable giga +2023-02-17 11:22:43,249 INFO [checkpoint.py:112] Loading checkpoint 
from icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/epoch-99.pt +2023-02-17 11:22:44,595 INFO [export-for-ncnn.py:324] encoder parameters: 83137520 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:325] decoder parameters: 257024 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:326] joiner parameters: 781812 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:327] total parameters: 84176356 +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:329] Using torch.jit.trace() +2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:331] Exporting encoder +2023-02-17 11:22:48,182 INFO [export-for-ncnn.py:158] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.pt +2023-02-17 11:22:48,183 INFO [export-for-ncnn.py:335] Exporting decoder +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/lstm_transducer_stateless2/decoder.py:101: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + need_pad = bool(need_pad) +2023-02-17 11:22:48,259 INFO [export-for-ncnn.py:180] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.pt +2023-02-17 11:22:48,259 INFO [export-for-ncnn.py:339] Exporting joiner +2023-02-17 11:22:48,304 INFO [export-for-ncnn.py:207] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.pt diff --git a/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..25874a414 --- /dev/null +++ b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt @@ -0,0 +1,74 @@ +2023-02-27 20:23:07,473 INFO [export-for-ncnn.py:246] device: cpu +2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:255] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-clean', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp'), 'bpe_model': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': '2,4,3,2,4', 'feedforward_dims': '1024,1024,2048,2048,1024', 'nhead': '8,8,8,8,8', 'encoder_dims': '384,384,384,384,384', 'attention_dims': '192,192,192,192,192', 
'encoder_unmasked_dims': '256,256,256,256,256', 'zipformer_downsampling_factors': '1,2,4,8,2', 'cnn_module_kernels': '31,31,31,31,31', 'decoder_dim': 512, 'joiner_dim': 512, 'short_chunk_size': 50, 'num_left_chunks': 4, 'decode_chunk_len': 32, 'blank_id': 0, 'vocab_size': 500} +2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:257] About to create model +2023-02-27 20:23:08,023 INFO [zipformer2.py:419] At encoder stack 4, which has downsampling_factor=2, we will combine the outputs of layers 1 and 3, with downsampling_factors=2 and 8. +2023-02-27 20:23:08,037 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/epoch-99.pt +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:346] encoder parameters: 68944004 +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:347] decoder parameters: 260096 +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:348] joiner parameters: 716276 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:349] total parameters: 69920376 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:351] Using torch.jit.trace() +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:353] Exporting encoder +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:174] decode_chunk_len: 32 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:175] T: 39 +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1344: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_len.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1348: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_avg.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1352: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1356: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1360: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 
+ assert cached_val2.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1364: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_conv1.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1368: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_conv2.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1373: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert self.left_context_len == cached_key.shape[1], ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1884: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert self.x_size == x.size(0), (self.x_size, x.size(0)) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2442: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == self.left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2449: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == cached_val.shape[0], ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2469: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2473: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 
+ assert cached_val.shape[0] == left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2483: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert kv_len == k.shape[0], (kv_len, k.shape) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2570: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2926: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2652: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert x.shape[0] == self.x_size, (x.shape[0], self.x_size) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2653: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2666: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.shape[0] == self.left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1543: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1637: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. 
This means that the trace might not generalize to other inputs! + assert src.shape[0] == self.in_x_size, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1643: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1571: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + if src.shape[0] != self.in_x_size: +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1763: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1779: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1780: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2) +/star-fj/fangjun/py38/lib/python3.8/site-packages/torch/jit/_trace.py:958: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior. + module._c._create_method_from_trace( +2023-02-27 20:23:19,640 INFO [export-for-ncnn.py:182] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt +2023-02-27 20:23:19,646 INFO [export-for-ncnn.py:357] Exporting decoder +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py:102: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 
+ assert embedding_out.size(-1) == self.context_size +2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:204] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt +2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:361] Exporting joiner +2023-02-27 20:23:19,735 INFO [export-for-ncnn.py:231] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt new file mode 100644 index 000000000..d39215b14 --- /dev/null +++ b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt @@ -0,0 +1,44 @@ +Don't Use GPU. has_gpu: 0, config.use_vulkan_compute: 1 +num encoder conv layers: 28 +num joiner conv layers: 3 +num files: 3 +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +Processing ../test_wavs/1089-134686-0001.wav +Processing ../test_wavs/1221-135766-0001.wav +Processing ../test_wavs/1221-135766-0002.wav +----------encoder---------- +conv_15 : max = 15.942385 threshold = 15.930708 scale = 7.972025 +conv_16 : max = 44.978855 threshold = 17.031788 scale = 7.456645 +conv_17 : max = 17.868437 threshold = 7.830528 scale = 16.218575 +linear_18 : max = 3.107259 threshold = 1.194808 scale = 106.293236 +linear_19 : max = 6.193777 threshold = 4.634748 scale = 27.401705 +linear_20 : max = 9.259933 threshold = 2.606617 scale = 48.722160 +linear_21 : max = 5.186600 threshold = 4.790260 scale = 26.512129 +linear_22 : max = 9.759041 threshold = 2.265832 scale = 56.050053 +linear_23 : max = 3.931209 threshold = 3.099090 scale = 40.979767 +linear_24 : max = 10.324160 threshold = 2.215561 scale = 57.321835 +linear_25 : max = 3.800708 threshold = 3.599352 scale = 35.284134 +linear_26 : max = 10.492444 threshold = 3.153369 scale = 40.274391 +linear_27 : max = 3.660161 threshold = 2.720994 scale = 46.674126 +linear_28 : max = 9.415265 threshold = 3.174434 scale = 40.007133 +linear_29 : max = 4.038418 threshold = 3.118534 scale = 40.724262 +linear_30 : max = 10.072084 threshold = 3.936867 scale = 32.259155 +linear_31 : max = 4.342712 threshold = 3.599489 scale = 35.282787 +linear_32 : max = 11.340535 threshold = 3.120308 scale = 40.701103 +linear_33 : max = 3.846987 threshold = 3.630030 scale = 34.985939 +linear_34 : max = 10.686298 threshold = 2.204571 scale = 57.607586 +linear_35 : max = 4.904821 threshold = 4.575518 scale = 27.756420 +linear_36 : max = 11.806659 threshold = 2.585589 scale = 49.118401 +linear_37 : max = 6.402340 threshold = 5.047157 scale = 25.162680 +linear_38 : max = 11.174589 threshold = 1.923361 scale = 66.030258 +linear_39 : max = 16.178576 threshold = 7.556058 scale = 16.807705 +linear_40 : max = 12.901954 threshold = 5.301267 scale = 23.956539 +linear_41 : max = 14.839805 threshold = 7.597429 scale = 16.716181 +linear_42 : max = 10.178945 threshold = 2.651595 scale = 47.895699 +----------joiner---------- +linear_2 : max = 24.829245 threshold = 16.627592 scale = 7.637907 +linear_1 : max = 10.746186 threshold = 5.255032 scale = 24.167313 +linear_3 : max = 1.000000 threshold = 0.999756 scale = 127.031013 +ncnn int8 calibration table create success, best wish for your int8 inference has a low accuracy loss...\(^0^)/...233... 
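In the calibration table above, each ``scale`` equals ``127 / threshold``: the calibration pass picks a per-layer clipping threshold and maps it onto the symmetric int8 range ``[-127, 127]``. This relationship is not stated in the log itself, but it can be checked directly against the printed rows, e.g. with the following Python sketch:

.. code-block:: python

    # Check scale == 127 / threshold using rows copied from the table above.
    rows = {
        "conv_15": (15.930708, 7.972025),
        "linear_18": (1.194808, 106.293236),
        "linear_3": (0.999756, 127.031013),
    }
    for name, (threshold, scale) in rows.items():
        assert abs(127.0 / threshold - scale) < 1e-4, name
    print("scale == 127 / threshold holds for all sampled rows")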
diff --git a/docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt similarity index 100% rename from docs/source/model-export/code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt rename to docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt new file mode 100644 index 000000000..3606eae3d --- /dev/null +++ b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt @@ -0,0 +1,6 @@ +2023-02-17 11:37:30,861 INFO [streaming-ncnn-decode.py:255] {'tokens': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav'} +2023-02-17 11:37:31,425 INFO [streaming-ncnn-decode.py:263] Constructing Fbank computer +2023-02-17 11:37:31,427 INFO [streaming-ncnn-decode.py:266] Reading sound files: ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav +2023-02-17 11:37:31,431 INFO [streaming-ncnn-decode.py:271] torch.Size([106000]) +2023-02-17 11:37:34,115 INFO [streaming-ncnn-decode.py:342] ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav +2023-02-17 11:37:34,115 INFO [streaming-ncnn-decode.py:343] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt new file mode 100644 index 000000000..5b4969e0f --- /dev/null +++ b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt @@ -0,0 +1,7 @@ +2023-02-27 20:43:40,283 INFO [streaming-ncnn-decode.py:349] {'tokens': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': 
'./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav'} +2023-02-27 20:43:41,260 INFO [streaming-ncnn-decode.py:357] Constructing Fbank computer +2023-02-27 20:43:41,264 INFO [streaming-ncnn-decode.py:360] Reading sound files: ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav +2023-02-27 20:43:41,269 INFO [streaming-ncnn-decode.py:365] torch.Size([106000]) +2023-02-27 20:43:41,280 INFO [streaming-ncnn-decode.py:372] number of states: 35 +2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:410] ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav +2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:411] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst new file mode 100644 index 000000000..12b370143 --- /dev/null +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -0,0 +1,753 @@ +.. _export_conv_emformer_transducer_models_to_ncnn: + +Export ConvEmformer transducer models to ncnn +============================================= + +We use the pre-trained model from the following repository as an example: + + - `<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05>`_ + +We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_. + +.. hint:: + + We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing. + +.. caution:: + + Please use a more recent version of PyTorch. For instance, ``torch 1.8`` + may ``not`` work. + +1. Download the pre-trained model +--------------------------------- + +.. hint:: + + You can also refer to ``_ to download the pre-trained model. + + You have to install `git-lfs`_ before you continue. + +.. code-block:: bash + + cd egs/librispeech/ASR + + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + + git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + + cd .. + +.. note:: + + We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``. + + +In the above code, we downloaded the pre-trained model into the directory +``egs/librispeech/ASR/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05``. + +.. _export_for_ncnn_install_ncnn_and_pnnx: + +2. Install ncnn and pnnx +------------------------ + +.. code-block:: bash + + # We put ncnn into $HOME/open-source/ncnn + # You can change it to anywhere you like + + cd $HOME + mkdir -p open-source + cd open-source + + git clone https://github.com/csukuangfj/ncnn + cd ncnn + git submodule update --recursive --init + + # Note: We don't use "python setup.py install" or "pip install ." here
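 # (A short gloss, based on ncnn's documented build flags: NCNN_PYTHON=ON
 # builds the ncnn Python bindings that we import later; NCNN_BUILD_TOOLS=ON
 # builds the tools directory, which contains the int8 quantization tools
 # such as ncnn2int8; the benchmark and examples are switched off only to
 # save compile time.)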
+ + mkdir -p build-wheel + cd build-wheel + + cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DNCNN_PYTHON=ON \ + -DNCNN_BUILD_BENCHMARK=OFF \ + -DNCNN_BUILD_EXAMPLES=OFF \ + -DNCNN_BUILD_TOOLS=ON \ + .. + + make -j4 + + cd .. + + # Note: $PWD here is $HOME/open-source/ncnn + + export PYTHONPATH=$PWD/python:$PYTHONPATH + export PATH=$PWD/tools/pnnx/build/src:$PATH + export PATH=$PWD/build-wheel/tools/quantize:$PATH + + # Now build pnnx + cd tools/pnnx + mkdir build + cd build + cmake .. + make -j4 pnnx + + ./src/pnnx + +Congratulations! You have successfully installed the following components: + + - ``pnnx``, which is an executable located in + ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use + it to convert models exported by ``torch.jit.trace()``. + - ``ncnn2int8``, which is an executable located in + ``$HOME/open-source/ncnn/build-wheel/tools/quantize``. We will use + it to quantize our models to ``int8``. + - ``ncnn.cpython-38-x86_64-linux-gnu.so``, which is a Python module located + in ``$HOME/open-source/ncnn/python/ncnn``. + + .. note:: + + I am using ``Python 3.8``, so it + is ``ncnn.cpython-38-x86_64-linux-gnu.so``. If you use a different + version, say, ``Python 3.9``, the name would be + ``ncnn.cpython-39-x86_64-linux-gnu.so``. + + If you are not using Linux, the file name will also be different. + But that does not matter. As long as you can compile it, it should work. + +We have set up ``PYTHONPATH`` so that you can use ``import ncnn`` in your +Python code. We have also set up ``PATH`` so that you can use +``pnnx`` and ``ncnn2int8`` later in your terminal. + +.. caution:: + + Please don't use `<https://github.com/tencent/ncnn>`_. + We have made some modifications to the official `ncnn`_. + + We will synchronize `<https://github.com/csukuangfj/ncnn>`_ periodically + with the official one. + +3. Export the model via torch.jit.trace() +----------------------------------------- + +First, let us rename our pre-trained model: + +.. code-block:: + + cd egs/librispeech/ASR + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp + + ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt + + cd ../.. + +Next, we use the following code to export our model: + +.. code-block:: bash + + dir=./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/ + + ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $dir/exp \ + --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 \ + --encoder-dim 512 + +.. caution:: + + If your model has different configuration parameters, please change them accordingly. + +.. hint:: + + We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``. + There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``. + + If you have trained a model yourself and have all checkpoints + available, please first use ``decode.py`` to tune ``--epoch --avg`` + and select the best combination with ``--use-averaged-model 1``. + +.. note:: + + You will see the following log output: + + .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt + + The log shows the model has ``75490012`` parameters, i.e., ``~75 M``.
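 As a quick sanity check (an illustrative snippet, not part of icefall), you can relate this parameter count to the checkpoint size on disk, since each ``float32`` parameter takes 4 bytes:

 .. code-block:: python

     # Rough estimate of the fp32 checkpoint size from the parameter
     # count reported in the export-for-ncnn.py log above.
     num_params = 75490012
     size_mb = num_params * 4 / 1024 / 1024  # 4 bytes per float32 parameter
     print(f"{size_mb:.2f} MB")  # ~287.97 MB

 ..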
code-block:: + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt + + You can see that the file size of the pre-trained model is ``289 MB``, which + is roughly equal to ``75490012*4/1024/1024 = 287.97 MB``. + +After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``, +we will get the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*pnnx* + + -rw-r--r-- 1 kuangfangjun root 1010K Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 283M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.pt + -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.pt + + +.. _conv-emformer-step-4-export-torchscript-model-via-pnnx: + +4. Export torchscript model via pnnx +------------------------------------ + +.. hint:: + + Make sure you have set up the ``PATH`` environment variable. Otherwise, + it will throw an error saying that ``pnnx`` could not be found. + +Now, it's time to export our models to `ncnn`_ via ``pnnx``. + +.. code-block:: + + cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ + + pnnx ./encoder_jit_trace-pnnx.pt + pnnx ./decoder_jit_trace-pnnx.pt + pnnx ./joiner_jit_trace-pnnx.pt + +It will generate the following files: + +.. code-block:: bash + + ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*ncnn*{bin,param} + + -rw-r--r-- 1 kuangfangjun root 503K Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 437 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 142M Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 79K Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param + -rw-r--r-- 1 kuangfangjun root 1.5M Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin + -rw-r--r-- 1 kuangfangjun root 488 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param + +There are two types of files: + +- ``param``: It is a text file containing the model architectures. You can + use a text editor to view its content. +- ``bin``: It is a binary file containing the model parameters. + +We compare the file sizes of the models below before and after converting via ``pnnx``: + +.. 
see https://tableconvert.com/restructuredtext-generator + ++----------------------------------+------------+ +| File name | File size | ++==================================+============+ +| encoder_jit_trace-pnnx.pt | 283 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.pt | 1010 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.pt | 3.0 MB | ++----------------------------------+------------+ +| encoder_jit_trace-pnnx.ncnn.bin | 142 MB | ++----------------------------------+------------+ +| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | ++----------------------------------+------------+ +| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | ++----------------------------------+------------+ + +You can see that the file sizes of the models after conversion are about half +those of the models before conversion: + + - encoder: 283 MB vs 142 MB + - decoder: 1010 KB vs 503 KB + - joiner: 3.0 MB vs 1.5 MB + +The reason is that by default ``pnnx`` converts ``float32`` parameters +to ``float16``. A ``float32`` parameter occupies 4 bytes, while a ``float16`` +parameter occupies only 2 bytes. Thus, each file is about half its original +size after conversion. + +.. hint:: + + If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` + won't convert ``float32`` to ``float16``. + +5. Test the exported models in icefall +-------------------------------------- + +.. note:: + + We assume you have set up the environment variable ``PYTHONPATH`` when + building `ncnn`_. + +Now we have successfully converted our pre-trained model to `ncnn`_ format. +The 6 generated files are what we need. You can use the following code to +test the converted models: + +.. code-block:: bash + + ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin \ + ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav + +.. hint:: + + `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts + only 1 wave file as input. + +The output is given below: + +.. literalinclude:: ./code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt + +Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! + + +.. _conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn: + +6. Modify the exported encoder for sherpa-ncnn +---------------------------------------------- + +In order to use the exported models in `sherpa-ncnn`_, we have to modify +``encoder_jit_trace-pnnx.ncnn.param``. + +Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: + + + ..
code-block:: + + 7767517 + 1060 1342 + Input in0 0 1 in0 + +**Explanation** of the above three lines: + + 1. ``7767517``, it is a magic number and should not be changed. + 2. ``1060 1342``, the first number ``1060`` specifies the number of layers + in this file, while ``1342`` specifies the number of intermediate outputs + of this file. + 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` + is the layer name of this layer; ``0`` means this layer has no input; + ``1`` means this layer has one output; ``in0`` is the output name of + this layer. + +We need to add 1 extra line and also increment the number of layers. +The result looks like the following: + +.. code-block:: bash + + 7767517 + 1061 1342 + SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 + Input in0 0 1 in0 + +**Explanation** + + 1. ``7767517``, it is still the same. + 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. + We don't need to change ``1342`` since the newly added layer has no inputs or outputs. + 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` + This line is newly added. Its explanation is given below: + + - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. + - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. + - ``0 0`` means this layer has no inputs or outputs. Must be ``0 0`` + - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` + - ``1=12``, 1 is the key and 12 is the value of the + parameter ``--num-encoder-layers`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``2=32``, 2 is the key and 32 is the value of the + parameter ``--memory-size`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``3=31``, 3 is the key and 31 is the value of the + parameter ``--cnn-module-kernel`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``4=8``, 4 is the key and 8 is the value of the + parameter ``--left-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``5=32``, 5 is the key and 32 is the value of the + parameter ``--chunk-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``6=8``, 6 is the key and 8 is the value of the + parameter ``--right-context-length`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + - ``7=512``, 7 is the key and 512 is the value of the + parameter ``--encoder-dim`` that you provided when running + ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. + + For ease of reference, we list the key-value pairs that you need to add + in the following table. If your model has a different setting, please + change the values for ``SherpaMetaData`` accordingly. Otherwise, you + will be ``SAD``. + + +------+-----------------------------+ + | key | value | + +======+=============================+ + | 0 | 1 (fixed) | + +------+-----------------------------+ + | 1 | ``--num-encoder-layers`` | + +------+-----------------------------+ + | 2 | ``--memory-size`` | + +------+-----------------------------+ + | 3 | ``--cnn-module-kernel`` | + +------+-----------------------------+ + | 4 | ``--left-context-length`` | + +------+-----------------------------+ + | 5 | ``--chunk-length`` | + +------+-----------------------------+ + | 6 | ``--right-context-length`` | + +------+-----------------------------+ + | 7 | ``--encoder-dim`` | + +------+-----------------------------+ + + 4. ``Input in0 0 1 in0``. No need to change it. + +.. caution:: + + When you add a new layer ``SherpaMetaData``, please remember to update the + number of layers. In our case, update ``1060`` to ``1061``. Otherwise, + you will be SAD later. + +.. hint:: + + After adding the new layer ``SherpaMetaData``, you cannot use this model + with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is + supported only in `sherpa-ncnn`_. + +.. hint:: + + `ncnn`_ is very flexible. You can add new layers to it just by text-editing + the ``param`` file! You don't need to change the ``bin`` file.
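Since the edit is purely mechanical, it can also be scripted. Below is a small, hypothetical Python sketch (it is not part of icefall; the ``SherpaMetaData`` line shown is the one from this tutorial, so adjust the values to your model) that inserts the line and bumps the layer count:

.. code-block:: python

    # Hypothetical helper (not part of icefall): insert a SherpaMetaData
    # line into an ncnn "param" file and increment its layer count.
    def add_sherpa_meta_data(param_file: str, meta_line: str) -> None:
        with open(param_file) as f:
            lines = f.read().splitlines()
        # Line 0 is the magic number; line 1 is "num_layers num_blobs".
        num_layers, num_blobs = map(int, lines[1].split())
        lines[1] = f"{num_layers + 1} {num_blobs}"
        lines.insert(2, meta_line)  # right before "Input in0 0 1 in0"
        with open(param_file, "w") as f:
            f.write("\n".join(lines) + "\n")

    add_sherpa_meta_data(
        "encoder_jit_trace-pnnx.ncnn.param",
        "SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512",
    )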
+
+Now you can use this model in `sherpa-ncnn`_.
+Please refer to the following documentation:
+
+  - Linux/macOS/Windows/arm/aarch64: ``_
+  - ``Android``: ``_
+  - ``iOS``: ``_
+  - Python: ``_
+
+We have a list of pre-trained models that have been exported for `sherpa-ncnn`_:
+
+  - ``_
+
+    You can find more usages there.
+
+7. (Optional) int8 quantization with sherpa-ncnn
+------------------------------------------------
+
+This step is optional.
+
+In this step, we describe how to quantize our model with ``int8``.
+
+Change :ref:`conv-emformer-step-4-export-torchscript-model-via-pnnx` to
+disable ``fp16`` when using ``pnnx``:
+
+.. code-block::
+
+   cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+   pnnx ./encoder_jit_trace-pnnx.pt fp16=0
+   pnnx ./decoder_jit_trace-pnnx.pt
+   pnnx ./joiner_jit_trace-pnnx.pt fp16=0
+
+.. note::
+
+   We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not
+   support quantizing the decoder model yet. We will update this documentation
+   once `ncnn`_ supports it. (Perhaps in 2023.)
+
+It will generate the following files:
+
+.. code-block:: bash
+
+   ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*_jit_trace-pnnx.ncnn.{param,bin}
+
+   -rw-r--r-- 1 kuangfangjun root 503K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  437 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param
+   -rw-r--r-- 1 kuangfangjun root 283M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  79K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param
+   -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  488 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param
+
+Let us compare again the file sizes:
+
++----------------------------------------+------------+
+| File name                              | File size  |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt              | 283 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt              | 1010 KB    |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt               | 3.0 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16)  | 1.5 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32)  | 3.0 MB     |
++----------------------------------------+------------+
+
+You can see that the file sizes are doubled when we disable ``fp16``.
+
+.. note::
+
+   You can again use ``streaming-ncnn-decode.py`` to test the exported models.
+
+Next, follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn`
+to modify ``encoder_jit_trace-pnnx.ncnn.param``.
+
+Change
+
+.. code-block:: bash
+
+   7767517
+   1060 1342
+   Input in0 0 1 in0
+
+to
+
+.. code-block:: bash
+
+   7767517
+   1061 1342
+   SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512
+   Input in0 0 1 in0
+
+.. caution::
+
+   Please follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn`
+   to change the values for ``SherpaMetaData`` if your model uses a different setting.
+
+
+Next, let us compile `sherpa-ncnn`_ since we will quantize our models within
+`sherpa-ncnn`_.
+
+.. code-block:: bash
+
+   # We will download sherpa-ncnn to $HOME/open-source/
+   # You can change it to anywhere you like.
+   cd $HOME
+   mkdir -p open-source
+
+   cd open-source
+   git clone https://github.com/k2-fsa/sherpa-ncnn
+   cd sherpa-ncnn
+   mkdir build
+   cd build
+   cmake ..
+   make -j 4
+
+   ./bin/generate-int8-scale-table
+
+   export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH
+
+The output of the above commands is:
+
+.. code-block:: bash
+
+   (py38) kuangfangjun:build$ generate-int8-scale-table
+   Please provide 10 arg. Currently given: 1
+   Usage:
+   generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt
+
+   Each line in wave_filenames.txt is a path to some 16k Hz mono wave file.
+
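+Before you list your calibration files, you may want to check that they
+really are 16 kHz mono waves. A quick way to do that (assuming ``sox`` is
+installed; this check is not part of icefall) is:
+
+.. code-block:: bash
+
+   # Run from egs/librispeech/ASR; prints sample rate and channel
+   # count for each candidate calibration wave file.
+   soxi icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/*.wav
+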
+We need to create a file ``wave_filenames.txt``, in which we put
+some calibration wave files. For testing purposes, we use the ``test_wavs``
+from the pre-trained model repository
+`<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05>`_
+
+.. code-block:: bash
+
+   cd egs/librispeech/ASR
+   cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+   cat <<EOF > wave_filenames.txt
+   ../test_wavs/1089-134686-0001.wav
+   ../test_wavs/1221-135766-0001.wav
+   ../test_wavs/1221-135766-0002.wav
+   EOF
+
+Now we can calculate the scales needed for quantization with the calibration data:
+
+.. code-block:: bash
+
+   cd egs/librispeech/ASR
+   cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+   generate-int8-scale-table \
+     ./encoder_jit_trace-pnnx.ncnn.param \
+     ./encoder_jit_trace-pnnx.ncnn.bin \
+     ./decoder_jit_trace-pnnx.ncnn.param \
+     ./decoder_jit_trace-pnnx.ncnn.bin \
+     ./joiner_jit_trace-pnnx.ncnn.param \
+     ./joiner_jit_trace-pnnx.ncnn.bin \
+     ./encoder-scale-table.txt \
+     ./joiner-scale-table.txt \
+     ./wave_filenames.txt
+
+The output logs are given below:
+
+.. literalinclude:: ./code/generate-int-8-scale-table-for-conv-emformer.txt
+
+It generates the following two files:
+
+.. code-block:: bash
+
+   $ ls -lh encoder-scale-table.txt joiner-scale-table.txt
+   -rw-r--r-- 1 kuangfangjun root 955K Jan 11 17:28 encoder-scale-table.txt
+   -rw-r--r-- 1 kuangfangjun root  18K Jan 11 17:28 joiner-scale-table.txt
+
+.. caution::
+
+   In practice, you need many more calibration files to compute an accurate
+   scale table.
+
+Finally, let us use the scale table to quantize our models into ``int8``.
+
+.. code-block:: bash
+
+   ncnn2int8
+
+   usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table]
+
+First, we quantize the encoder model:
+
+.. code-block:: bash
+
+   cd egs/librispeech/ASR
+   cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+   ncnn2int8 \
+     ./encoder_jit_trace-pnnx.ncnn.param \
+     ./encoder_jit_trace-pnnx.ncnn.bin \
+     ./encoder_jit_trace-pnnx.ncnn.int8.param \
+     ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+     ./encoder-scale-table.txt
+
+Next, we quantize the joiner model:
+
+.. code-block:: bash
+
+   ncnn2int8 \
+     ./joiner_jit_trace-pnnx.ncnn.param \
+     ./joiner_jit_trace-pnnx.ncnn.bin \
+     ./joiner_jit_trace-pnnx.ncnn.int8.param \
+     ./joiner_jit_trace-pnnx.ncnn.int8.bin \
+     ./joiner-scale-table.txt
+
+The above two commands generate the following 4 files:
+
+.. code-block:: bash
+
+   -rw-r--r-- 1 kuangfangjun root  99M Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.bin
+   -rw-r--r-- 1 kuangfangjun root  78K Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.param
+   -rw-r--r-- 1 kuangfangjun root 774K Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.bin
+   -rw-r--r-- 1 kuangfangjun root  496 Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.param
+
+Congratulations! You have successfully quantized your model from ``float32`` to ``int8``.
+
+.. caution::
+
+   ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs.
+
+   You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param``
+   and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like.
+
+   For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can
+   replace the following invocation:
+
+   .. code-block:: bash
+
+      cd egs/librispeech/ASR
+      cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+      sherpa-ncnn \
+        ../data/lang_bpe_500/tokens.txt \
+        ./encoder_jit_trace-pnnx.ncnn.param \
+        ./encoder_jit_trace-pnnx.ncnn.bin \
+        ./decoder_jit_trace-pnnx.ncnn.param \
+        ./decoder_jit_trace-pnnx.ncnn.bin \
+        ./joiner_jit_trace-pnnx.ncnn.param \
+        ./joiner_jit_trace-pnnx.ncnn.bin \
+        ../test_wavs/1089-134686-0001.wav
+
+   with
+
+   .. code-block:: bash
+
+      cd egs/librispeech/ASR
+      cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+      sherpa-ncnn \
+        ../data/lang_bpe_500/tokens.txt \
+        ./encoder_jit_trace-pnnx.ncnn.int8.param \
+        ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+        ./decoder_jit_trace-pnnx.ncnn.param \
+        ./decoder_jit_trace-pnnx.ncnn.bin \
+        ./joiner_jit_trace-pnnx.ncnn.param \
+        ./joiner_jit_trace-pnnx.ncnn.bin \
+        ../test_wavs/1089-134686-0001.wav
+
+
+The following table compares again the file sizes:
+
++----------------------------------------+------------+
+| File name                              | File size  |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt              | 283 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt              | 1010 KB    |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt               | 3.0 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16)  | 1.5 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32)  | 3.0 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.int8.bin   | 99 MB      |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.int8.bin    | 774 KB     |
++----------------------------------------+------------+
+
+You can see that the file sizes of the models after ``int8`` quantization
+are much smaller.
+
+.. hint::
+
+   Currently, only linear layers and convolutional layers are quantized
+   with ``int8``, so you don't see an exact ``4x`` reduction in file sizes.
+
+.. note::
+
+   You need to test the recognition accuracy after ``int8`` quantization.
+
+You can find the speed comparison at ``_.
+
+
+That's it! Have fun with `sherpa-ncnn`_!
diff --git a/docs/source/model-export/export-ncnn-lstm.rst b/docs/source/model-export/export-ncnn-lstm.rst
new file mode 100644
index 000000000..8e6dc7466
--- /dev/null
+++ b/docs/source/model-export/export-ncnn-lstm.rst
@@ -0,0 +1,644 @@
+.. _export_lstm_transducer_models_to_ncnn:
+
+Export LSTM transducer models to ncnn
+-------------------------------------
+
+We use the pre-trained model from the following repository as an example:
+
+`<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>`_
+
+We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_.
+
+.. hint::
+
+   We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing.
+
+.. caution::
+
+   Please use a more recent version of PyTorch. For instance, ``torch 1.8``
+   may ``not`` work.
+
+1. Download the pre-trained model
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. hint::
+
+   You have to install `git-lfs`_ before you continue.
+
+.. code-block:: bash
+
+   cd egs/librispeech/ASR
+   GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+   cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+
+   git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
+   git lfs pull --include "data/lang_bpe_500/bpe.model"
+
+   cd ..
+
+.. note::
+
+   We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``.
+
+In the above code, we downloaded the pre-trained model into the directory
+``egs/librispeech/ASR/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03``.
+
+2. Install ncnn and pnnx
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx`.
+
+
+3. Export the model via torch.jit.trace()
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+First, let us rename our pre-trained model:
+
+.. code-block::
+
+   cd egs/librispeech/ASR
+
+   cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp
+
+   ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+
+   cd ../..
+
+Next, we use the following code to export our model:
+
+.. code-block:: bash
+
+   dir=./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+
+   ./lstm_transducer_stateless2/export-for-ncnn.py \
+     --exp-dir $dir/exp \
+     --bpe-model $dir/data/lang_bpe_500/bpe.model \
+     --epoch 99 \
+     --avg 1 \
+     --use-averaged-model 0 \
+     --num-encoder-layers 12 \
+     --encoder-dim 512 \
+     --rnn-hidden-size 1024
+
+.. hint::
+
+   We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``.
+   There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``.
+
+   If you have trained a model by yourself and if you have all checkpoints
+   available, please first use ``decode.py`` to tune ``--epoch --avg``
+   and select the best combination with ``--use-averaged-model 1``.
+
+.. note::
+
+   You will see the following log output:
+
+   .. literalinclude:: ./code/export-lstm-transducer-for-ncnn-output.txt
+
+   The log shows the model has ``84176356`` parameters, i.e., ``~84 M``.
+
+   .. code-block::
+
+      ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt
+
+      -rw-r--r-- 1 kuangfangjun root 324M Feb 17 10:34 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt
+
+   You can see that the file size of the pre-trained model is ``324 MB``, which
+   is roughly equal to ``84176356*4/1024/1024 = 321.107 MB``.
+
+After running ``lstm_transducer_stateless2/export-for-ncnn.py``,
+we will get the following files:
+
+.. code-block:: bash
+
+   ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*pnnx.pt
+
+   -rw-r--r-- 1 kuangfangjun root 1010K Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.pt
+   -rw-r--r-- 1 kuangfangjun root  318M Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.pt
+   -rw-r--r-- 1 kuangfangjun root  3.0M Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.pt
+
+
+.. _lstm-transducer-step-4-export-torchscript-model-via-pnnx:
+
+4. Export torchscript model via pnnx
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. hint::
+
+   Make sure you have set up the ``PATH`` environment variable
+   in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise,
+   it will throw an error saying that ``pnnx`` could not be found.
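+
+You can quickly verify that ``pnnx`` is found (a small sanity check, not
+part of the icefall scripts):
+
+.. code-block:: bash
+
+   # Should print the full path of the pnnx executable; if it prints
+   # nothing, revisit the PATH setup mentioned in the hint above.
+   which pnnx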
+
+Now, it's time to export our models to `ncnn`_ via ``pnnx``.
+
+.. code-block::
+
+   cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+   pnnx ./encoder_jit_trace-pnnx.pt
+   pnnx ./decoder_jit_trace-pnnx.pt
+   pnnx ./joiner_jit_trace-pnnx.pt
+
+It will generate the following files:
+
+.. code-block:: bash
+
+   ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*ncnn*{bin,param}
+
+   -rw-r--r-- 1 kuangfangjun root 503K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  437 Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param
+   -rw-r--r-- 1 kuangfangjun root 159M Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  21K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param
+   -rw-r--r-- 1 kuangfangjun root 1.5M Feb 17 11:33 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  488 Feb 17 11:33 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param
+
+
+There are two types of files:
+
+- ``param``: It is a text file containing the model architectures. You can
+  use a text editor to view its content.
+- ``bin``: It is a binary file containing the model parameters.
+
+We compare below the file sizes of the models before and after conversion via ``pnnx``:
+
+.. see https://tableconvert.com/restructuredtext-generator
+
++----------------------------------+------------+
+| File name                        | File size  |
++==================================+============+
+| encoder_jit_trace-pnnx.pt        | 318 MB     |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.pt        | 1010 KB    |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.pt         | 3.0 MB     |
++----------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin  | 159 MB     |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin  | 503 KB     |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin   | 1.5 MB     |
++----------------------------------+------------+
+
+You can see that the file sizes of the models after conversion are about half
+of those before conversion:
+
+  - encoder: 318 MB vs 159 MB
+  - decoder: 1010 KB vs 503 KB
+  - joiner: 3.0 MB vs 1.5 MB
+
+The reason is that by default ``pnnx`` converts ``float32`` parameters
+to ``float16``. A ``float32`` parameter occupies 4 bytes, while a ``float16``
+parameter occupies only 2 bytes. Thus, the file size is roughly halved after
+conversion.
+
+.. hint::
+
+   If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx``
+   won't convert ``float32`` to ``float16``.
+
+5. Test the exported models in icefall
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. note::
+
+   We assume you have set up the environment variable ``PYTHONPATH`` when
+   building `ncnn`_.
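+
+You can confirm that Python can locate the ``ncnn`` module (a quick check,
+not part of the icefall scripts):
+
+.. code-block:: bash
+
+   # Should print the path of the ncnn Python module built earlier;
+   # an ImportError means PYTHONPATH is not set up correctly.
+   python3 -c "import ncnn; print(ncnn.__file__)"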
+
+Now we have successfully converted our pre-trained model to `ncnn`_ format.
+The 6 generated files are all we need. You can use the following code to
+test the converted models:
+
+.. code-block:: bash
+
+   python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
+     --tokens ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/tokens.txt \
+     --encoder-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param \
+     --encoder-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin \
+     --decoder-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param \
+     --decoder-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin \
+     --joiner-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param \
+     --joiner-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin \
+     ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav
+
+.. hint::
+
+   `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts
+   only 1 wave file as input.
+
+The output is given below:
+
+.. literalinclude:: ./code/test-streaming-ncnn-decode-lstm-transducer-libri.txt
+
+Congratulations! You have successfully exported a model from PyTorch to `ncnn`_!
+
+.. _lstm-modify-the-exported-encoder-for-sherpa-ncnn:
+
+6. Modify the exported encoder for sherpa-ncnn
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to use the exported models in `sherpa-ncnn`_, we have to modify
+``encoder_jit_trace-pnnx.ncnn.param``.
+
+Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``:
+
+.. code-block::
+
+   7767517
+   267 379
+   Input in0 0 1 in0
+
+**Explanation** of the above three lines:
+
+  1. ``7767517``, it is a magic number and should not be changed.
+  2. ``267 379``, the first number ``267`` specifies the number of layers
+     in this file, while ``379`` specifies the number of intermediate outputs
+     of this file.
+  3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0``
+     is the layer name of this layer; ``0`` means this layer has no input;
+     ``1`` means this layer has one output; ``in0`` is the output name of
+     this layer.
+
+We need to add 1 extra line and also increment the number of layers.
+The result looks like below:
+
+.. code-block:: bash
+
+   7767517
+   268 379
+   SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024
+   Input in0 0 1 in0
+
+**Explanation**
+
+  1. ``7767517``, it is still the same.
+  2. ``268 379``, we have added an extra layer, so we need to update ``267`` to ``268``.
+     We don't need to change ``379`` since the newly added layer has no inputs or outputs.
+  3. ``SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024``
+     This line is newly added. Its explanation is given below:
+
+     - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``.
+     - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``.
+     - ``0 0`` means this layer has no inputs or outputs. Must be ``0 0``.
+     - ``0=3``, 0 is the key and 3 is the value. MUST be ``0=3``.
+     - ``1=12``, 1 is the key and 12 is the value of the
+       parameter ``--num-encoder-layers`` that you provided when running
+       ``./lstm_transducer_stateless2/export-for-ncnn.py``.
+     - ``2=512``, 2 is the key and 512 is the value of the
+       parameter ``--encoder-dim`` that you provided when running
+       ``./lstm_transducer_stateless2/export-for-ncnn.py``.
+     - ``3=1024``, 3 is the key and 1024 is the value of the
+       parameter ``--rnn-hidden-size`` that you provided when running
+       ``./lstm_transducer_stateless2/export-for-ncnn.py``.
+
+     For ease of reference, we list the key-value pairs that you need to add
+     in the following table. If your model has a different setting, please
+     change the values for ``SherpaMetaData`` accordingly. Otherwise, you
+     will be ``SAD``.
+
+       +------+-----------------------------+
+       | key  | value                       |
+       +======+=============================+
+       | 0    | 3 (fixed)                   |
+       +------+-----------------------------+
+       | 1    | ``--num-encoder-layers``    |
+       +------+-----------------------------+
+       | 2    | ``--encoder-dim``           |
+       +------+-----------------------------+
+       | 3    | ``--rnn-hidden-size``       |
+       +------+-----------------------------+
+
+  4. ``Input in0 0 1 in0``. No need to change it.
+
+.. caution::
+
+   When you add a new layer ``SherpaMetaData``, please remember to update the
+   number of layers. In our case, update ``267`` to ``268``. Otherwise,
+   you will be SAD later.
+
+.. hint::
+
+   After adding the new layer ``SherpaMetaData``, you cannot use this model
+   with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is
+   supported only in `sherpa-ncnn`_.
+
+.. hint::
+
+   `ncnn`_ is very flexible. You can add new layers to it just by text-editing
+   the ``param`` file! You don't need to change the ``bin`` file.
+
+Now you can use this model in `sherpa-ncnn`_.
+Please refer to the following documentation:
+
+  - Linux/macOS/Windows/arm/aarch64: ``_
+  - ``Android``: ``_
+  - ``iOS``: ``_
+  - Python: ``_
+
+We have a list of pre-trained models that have been exported for `sherpa-ncnn`_:
+
+  - ``_
+
+    You can find more usages there.
+
+7. (Optional) int8 quantization with sherpa-ncnn
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This step is optional.
+
+In this step, we describe how to quantize our model with ``int8``.
+
+Change :ref:`lstm-transducer-step-4-export-torchscript-model-via-pnnx` to
+disable ``fp16`` when using ``pnnx``:
+
+.. code-block::
+
+   cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+   pnnx ./encoder_jit_trace-pnnx.pt fp16=0
+   pnnx ./decoder_jit_trace-pnnx.pt
+   pnnx ./joiner_jit_trace-pnnx.pt fp16=0
+
+.. note::
+
+   We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not
+   support quantizing the decoder model yet. We will update this documentation
+   once `ncnn`_ supports it. (Perhaps in 2023.)
+
+It will generate the following files:
+
+.. code-block:: bash
+
+   ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*_jit_trace-pnnx.ncnn.{param,bin}
+
+   -rw-r--r-- 1 kuangfangjun root 503K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  437 Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param
+   -rw-r--r-- 1 kuangfangjun root 317M Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  21K Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param
+   -rw-r--r-- 1 kuangfangjun root 3.0M Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  488 Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param
+
+
+Let us compare again the file sizes:
+
++----------------------------------------+------------+
+| File name                              | File size  |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt              | 318 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt              | 1010 KB    |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt               | 3.0 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 159 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16)  | 1.5 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 317 MB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32)  | 3.0 MB     |
++----------------------------------------+------------+
+
+You can see that the file sizes are doubled when we disable ``fp16``.
+
+.. note::
+
+   You can again use ``streaming-ncnn-decode.py`` to test the exported models.
+
+Next, follow :ref:`lstm-modify-the-exported-encoder-for-sherpa-ncnn`
+to modify ``encoder_jit_trace-pnnx.ncnn.param``.
+
+Change
+
+.. code-block:: bash
+
+   7767517
+   267 379
+   Input in0 0 1 in0
+
+to
+
+.. code-block:: bash
+
+   7767517
+   268 379
+   SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024
+   Input in0 0 1 in0
+
+.. caution::
+
+   Please follow :ref:`lstm-modify-the-exported-encoder-for-sherpa-ncnn`
+   to change the values for ``SherpaMetaData`` if your model uses a different setting.
+
+Next, let us compile `sherpa-ncnn`_ since we will quantize our models within
+`sherpa-ncnn`_.
+
+.. code-block:: bash
+
+   # We will download sherpa-ncnn to $HOME/open-source/
+   # You can change it to anywhere you like.
+   cd $HOME
+   mkdir -p open-source
+
+   cd open-source
+   git clone https://github.com/k2-fsa/sherpa-ncnn
+   cd sherpa-ncnn
+   mkdir build
+   cd build
+   cmake ..
+   make -j 4
+
+   ./bin/generate-int8-scale-table
+
+   export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH
+
+The output of the above commands is:
+
+.. code-block:: bash
+
+   (py38) kuangfangjun:build$ generate-int8-scale-table
+   Please provide 10 arg. Currently given: 1
+   Usage:
+   generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt
+
+   Each line in wave_filenames.txt is a path to some 16k Hz mono wave file.
+
+We need to create a file ``wave_filenames.txt``, in which we put
+some calibration wave files. For testing purposes, we use the ``test_wavs``
+from the pre-trained model repository
+`<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>`_
+
+.. code-block:: bash
+
+   cd egs/librispeech/ASR
+   cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+   cat <<EOF > wave_filenames.txt
+   ../test_wavs/1089-134686-0001.wav
+   ../test_wavs/1221-135766-0001.wav
+   ../test_wavs/1221-135766-0002.wav
+   EOF
+
+Now we can calculate the scales needed for quantization with the calibration data:
+
+.. code-block:: bash
+
+   cd egs/librispeech/ASR
+   cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+   generate-int8-scale-table \
+     ./encoder_jit_trace-pnnx.ncnn.param \
+     ./encoder_jit_trace-pnnx.ncnn.bin \
+     ./decoder_jit_trace-pnnx.ncnn.param \
+     ./decoder_jit_trace-pnnx.ncnn.bin \
+     ./joiner_jit_trace-pnnx.ncnn.param \
+     ./joiner_jit_trace-pnnx.ncnn.bin \
+     ./encoder-scale-table.txt \
+     ./joiner-scale-table.txt \
+     ./wave_filenames.txt
+
+The output logs are given below:
+
+.. literalinclude:: ./code/generate-int-8-scale-table-for-lstm.txt
+
+It generates the following two files:
+
+.. code-block:: bash
+
+   ls -lh encoder-scale-table.txt joiner-scale-table.txt
+
+   -rw-r--r-- 1 kuangfangjun root 345K Feb 17 12:13 encoder-scale-table.txt
+   -rw-r--r-- 1 kuangfangjun root  17K Feb 17 12:13 joiner-scale-table.txt
+
+.. caution::
+
+   In practice, you need many more calibration files to compute an accurate
+   scale table.
+
+Finally, let us use the scale table to quantize our models into ``int8``.
+
+.. code-block:: bash
+
+   ncnn2int8
+
+   usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table]
+
+First, we quantize the encoder model:
+
+.. code-block:: bash
+
+   cd egs/librispeech/ASR
+   cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+   ncnn2int8 \
+     ./encoder_jit_trace-pnnx.ncnn.param \
+     ./encoder_jit_trace-pnnx.ncnn.bin \
+     ./encoder_jit_trace-pnnx.ncnn.int8.param \
+     ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+     ./encoder-scale-table.txt
+
+Next, we quantize the joiner model:
+
+.. code-block:: bash
+
+   ncnn2int8 \
+     ./joiner_jit_trace-pnnx.ncnn.param \
+     ./joiner_jit_trace-pnnx.ncnn.bin \
+     ./joiner_jit_trace-pnnx.ncnn.int8.param \
+     ./joiner_jit_trace-pnnx.ncnn.int8.bin \
+     ./joiner-scale-table.txt
+
+The above two commands generate the following 4 files:
+
+.. code-block::
+
+   -rw-r--r-- 1 kuangfangjun root 218M Feb 17 12:19 encoder_jit_trace-pnnx.ncnn.int8.bin
+   -rw-r--r-- 1 kuangfangjun root  21K Feb 17 12:19 encoder_jit_trace-pnnx.ncnn.int8.param
+   -rw-r--r-- 1 kuangfangjun root 774K Feb 17 12:19 joiner_jit_trace-pnnx.ncnn.int8.bin
+   -rw-r--r-- 1 kuangfangjun root  496 Feb 17 12:19 joiner_jit_trace-pnnx.ncnn.int8.param
+
+Congratulations! You have successfully quantized your model from ``float32`` to ``int8``.
+
+.. caution::
+
+   ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs.
+
+   You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param``
+   and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like.
+
+   For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can
+   replace the following invocation:
+
+   .. code-block:: bash
+
+      cd egs/librispeech/ASR
+      cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+      sherpa-ncnn \
+        ../data/lang_bpe_500/tokens.txt \
+        ./encoder_jit_trace-pnnx.ncnn.param \
+        ./encoder_jit_trace-pnnx.ncnn.bin \
+        ./decoder_jit_trace-pnnx.ncnn.param \
+        ./decoder_jit_trace-pnnx.ncnn.bin \
+        ./joiner_jit_trace-pnnx.ncnn.param \
+        ./joiner_jit_trace-pnnx.ncnn.bin \
+        ../test_wavs/1089-134686-0001.wav
+
+   with
+
+   .. code-block:: bash
+
+      cd egs/librispeech/ASR
+      cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+      sherpa-ncnn \
+        ../data/lang_bpe_500/tokens.txt \
+        ./encoder_jit_trace-pnnx.ncnn.int8.param \
+        ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+        ./decoder_jit_trace-pnnx.ncnn.param \
+        ./decoder_jit_trace-pnnx.ncnn.bin \
+        ./joiner_jit_trace-pnnx.ncnn.param \
+        ./joiner_jit_trace-pnnx.ncnn.bin \
+        ../test_wavs/1089-134686-0001.wav
+
+The following table compares again the file sizes:
+
++----------------------------------------+------------+
+| File name                              | File size  |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt              | 318 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt              | 1010 KB    |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt               | 3.0 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 159 MB     |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16)  | 1.5 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 317 MB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32)  | 3.0 MB     |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.int8.bin   | 218 MB     |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.int8.bin    | 774 KB     |
++----------------------------------------+------------+
+
+You can see that the file size of the joiner model after ``int8`` quantization
+is much smaller. However, the size of the encoder model is even larger than
+the ``fp16`` counterpart. The reason is that `ncnn`_ currently does not support
+quantizing ``LSTM`` layers into ``8-bit``. Please see
+``_
+
+.. hint::
+
+   Currently, only linear layers and convolutional layers are quantized
+   with ``int8``, so you don't see an exact ``4x`` reduction in file sizes.
+
+.. note::
+
+   You need to test the recognition accuracy after ``int8`` quantization.
+
+
+That's it! Have fun with `sherpa-ncnn`_!
diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst
new file mode 100644
index 000000000..5c81d25ca
--- /dev/null
+++ b/docs/source/model-export/export-ncnn-zipformer.rst
@@ -0,0 +1,383 @@
+.. _export_streaming_zipformer_transducer_models_to_ncnn:
+
+Export streaming Zipformer transducer models to ncnn
+----------------------------------------------------
+
+We use the pre-trained model from the following repository as an example:
+
+`<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_
+
+We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_.
+
+.. hint::
+
+   We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing.
+
+.. caution::
+
+   Please use a more recent version of PyTorch. For instance, ``torch 1.8``
+   may ``not`` work.
+
+1. Download the pre-trained model
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. hint::
+
+   You have to install `git-lfs`_ before you continue.
+
+
+.. code-block:: bash
+
+   cd egs/librispeech/ASR
+   GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+   cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+
+   git lfs pull --include "exp/pretrained.pt"
+   git lfs pull --include "data/lang_bpe_500/bpe.model"
+
+   cd ..
+
+.. note::
+
+   We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``.
+
+In the above code, we downloaded the pre-trained model into the directory
+``egs/librispeech/ASR/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29``.
+
+2. Install ncnn and pnnx
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx`.
+
+
+3. Export the model via torch.jit.trace()
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+First, let us rename our pre-trained model:
+
+.. code-block::
+
+   cd egs/librispeech/ASR
+
+   cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
+
+   ln -s pretrained.pt epoch-99.pt
+
+   cd ../..
+
+Next, we use the following code to export our model:
+
+.. code-block:: bash
+
+   dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+
+   ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
+     --bpe-model $dir/data/lang_bpe_500/bpe.model \
+     --exp-dir $dir/exp \
+     --use-averaged-model 0 \
+     --epoch 99 \
+     --avg 1 \
+     \
+     --decode-chunk-len 32 \
+     --num-left-chunks 4 \
+     --num-encoder-layers "2,4,3,2,4" \
+     --feedforward-dims "1024,1024,2048,2048,1024" \
+     --nhead "8,8,8,8,8" \
+     --encoder-dims "384,384,384,384,384" \
+     --attention-dims "192,192,192,192,192" \
+     --encoder-unmasked-dims "256,256,256,256,256" \
+     --zipformer-downsampling-factors "1,2,4,8,2" \
+     --cnn-module-kernels "31,31,31,31,31" \
+     --decoder-dim 512 \
+     --joiner-dim 512
+
+.. caution::
+
+   If your model has different configuration parameters, please change them accordingly.
+
+.. hint::
+
+   We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``.
+   There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``.
+
+   If you have trained a model by yourself and if you have all checkpoints
+   available, please first use ``decode.py`` to tune ``--epoch --avg``
+   and select the best combination with ``--use-averaged-model 1``.
+
+.. note::
+
+   You will see the following log output:
+
+   .. literalinclude:: ./code/export-zipformer-transducer-for-ncnn-output.txt
+
+   The log shows the model has ``69920376`` parameters, i.e., ``~69.9 M``.
+
+   .. code-block:: bash
+
+      ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt
+      -rw-r--r-- 1 kuangfangjun root 269M Jan 12 12:53 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt
+
+   You can see that the file size of the pre-trained model is ``269 MB``, which
+   is roughly equal to ``69920376*4/1024/1024 = 266.725 MB``.
+
+After running ``pruned_transducer_stateless7_streaming/export-for-ncnn.py``,
+we will get the following files:
+
+.. code-block:: bash
+
+   ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*pnnx.pt
+
+   -rw-r--r-- 1 kuangfangjun root 1022K Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt
+   -rw-r--r-- 1 kuangfangjun root  266M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt
+   -rw-r--r-- 1 kuangfangjun root  2.8M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt
+
+.. _zipformer-transducer-step-4-export-torchscript-model-via-pnnx:
+
+4. Export torchscript model via pnnx
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. hint::
+
+   Make sure you have set up the ``PATH`` environment variable
+   in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise,
+   it will throw an error saying that ``pnnx`` could not be found.
+
+Now, it's time to export our models to `ncnn`_ via ``pnnx``.
+
+.. code-block::
+
+   cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
+
+   pnnx ./encoder_jit_trace-pnnx.pt
+   pnnx ./decoder_jit_trace-pnnx.pt
+   pnnx ./joiner_jit_trace-pnnx.pt
+
+It will generate the following files:
+
+.. code-block:: bash
+
+   ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*ncnn*{bin,param}
+
+   -rw-r--r-- 1 kuangfangjun root 509K Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  437 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param
+   -rw-r--r-- 1 kuangfangjun root 133M Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root 152K Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param
+   -rw-r--r-- 1 kuangfangjun root 1.4M Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin
+   -rw-r--r-- 1 kuangfangjun root  488 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param
+
+There are two types of files:
+
+- ``param``: It is a text file containing the model architectures. You can
+  use a text editor to view its content.
+- ``bin``: It is a binary file containing the model parameters.
+
+We compare below the file sizes of the models before and after conversion via ``pnnx``:
+
+.. see https://tableconvert.com/restructuredtext-generator
+
++----------------------------------+------------+
+| File name                        | File size  |
++==================================+============+
+| encoder_jit_trace-pnnx.pt        | 266 MB     |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.pt        | 1022 KB    |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.pt         | 2.8 MB     |
++----------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin  | 133 MB     |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin  | 509 KB     |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin   | 1.4 MB     |
++----------------------------------+------------+
+
+You can see that the file sizes of the models after conversion are about half
+of those before conversion:
+
+  - encoder: 266 MB vs 133 MB
+  - decoder: 1022 KB vs 509 KB
+  - joiner: 2.8 MB vs 1.4 MB
+
+The reason is that by default ``pnnx`` converts ``float32`` parameters
+to ``float16``. A ``float32`` parameter occupies 4 bytes, while a ``float16``
+parameter occupies only 2 bytes. Thus, the file size is roughly halved after
+conversion.
+
+.. hint::
+
+   If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx``
+   won't convert ``float32`` to ``float16``.
+
+5. Test the exported models in icefall
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. note::
+
+   We assume you have set up the environment variable ``PYTHONPATH`` when
+   building `ncnn`_.
+
+Now we have successfully converted our pre-trained model to `ncnn`_ format.
+The 6 generated files are all we need. You can use the following code to
+test the converted models:
+
+.. code-block:: bash
+
+   python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
+     --tokens ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt \
+     --encoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param \
+     --encoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin \
+     --decoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param \
+     --decoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin \
+     --joiner-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param \
+     --joiner-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin \
+     ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav
+
+.. hint::
+
+   `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts
+   only 1 wave file as input.
+
+The output is given below:
+
+.. literalinclude:: ./code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt
+
+Congratulations! You have successfully exported a model from PyTorch to `ncnn`_!
+
+.. _zipformer-modify-the-exported-encoder-for-sherpa-ncnn:
+
+6. Modify the exported encoder for sherpa-ncnn
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to use the exported models in `sherpa-ncnn`_, we have to modify
+``encoder_jit_trace-pnnx.ncnn.param``.
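+
+Since we are going to edit this file by hand, it is a good idea to keep a
+backup first (an optional step, not part of icefall):
+
+.. code-block:: bash
+
+   # Keep a pristine copy so you can diff against it or start over.
+   cp encoder_jit_trace-pnnx.ncnn.param encoder_jit_trace-pnnx.ncnn.param.bak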
+
+Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``:
+
+.. code-block::
+
+   7767517
+   2028 2547
+   Input in0 0 1 in0
+
+**Explanation** of the above three lines:
+
+  1. ``7767517``, it is a magic number and should not be changed.
+  2. ``2028 2547``, the first number ``2028`` specifies the number of layers
+     in this file, while ``2547`` specifies the number of intermediate outputs
+     of this file.
+  3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0``
+     is the layer name of this layer; ``0`` means this layer has no input;
+     ``1`` means this layer has one output; ``in0`` is the output name of
+     this layer.
+
+We need to add 1 extra line and also increment the number of layers.
+The result looks like below:
+
+.. code-block:: bash
+
+   7767517
+   2029 2547
+   SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31
+   Input in0 0 1 in0
+
+**Explanation**
+
+  1. ``7767517``, it is still the same.
+  2. ``2029 2547``, we have added an extra layer, so we need to update ``2028`` to ``2029``.
+     We don't need to change ``2547`` since the newly added layer has no inputs or outputs.
+  3. ``SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31``
+     This line is newly added. Its explanation is given below:
+
+     - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``.
+     - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``.
+     - ``0 0`` means this layer has no inputs or outputs. Must be ``0 0``.
+     - ``0=2``, 0 is the key and 2 is the value. MUST be ``0=2``.
+     - ``1=32``, 1 is the key and 32 is the value of the
+       parameter ``--decode-chunk-len`` that you provided when running
+       ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+     - ``2=4``, 2 is the key and 4 is the value of the
+       parameter ``--num-left-chunks`` that you provided when running
+       ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+     - ``3=7``, 3 is the key and 7 is the value for the amount of padding
+       used in the Conv2DSubsampling layer. It should be 7 for zipformer
+       if you don't change zipformer.py.
+     - ``-23316=5,2,4,3,2,4``, attribute 16, this is an array attribute.
+       It is attribute 16 since -23300 - (-23316) = 16.
+       The first element of the array is the length of the array, which is 5 in our case.
+       ``2,4,3,2,4`` is the value of ``--num-encoder-layers`` that you provided
+       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+     - ``-23317=5,384,384,384,384,384``, attribute 17.
+       The first element of the array is the length of the array, which is 5 in our case.
+       ``384,384,384,384,384`` is the value of ``--encoder-dims`` that you provided
+       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+     - ``-23318=5,192,192,192,192,192``, attribute 18.
+       The first element of the array is the length of the array, which is 5 in our case.
+       ``192,192,192,192,192`` is the value of ``--attention-dims`` that you provided
+       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+     - ``-23319=5,1,2,4,8,2``, attribute 19.
+       The first element of the array is the length of the array, which is 5 in our case.
+       ``1,2,4,8,2`` is the value of ``--zipformer-downsampling-factors`` that you provided
+       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+     - ``-23320=5,31,31,31,31,31``, attribute 20.
+       The first element of the array is the length of the array, which is 5 in our case.
+       ``31,31,31,31,31`` is the value of ``--cnn-module-kernels`` that you provided
+       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+
+     For ease of reference, we list the key-value pairs that you need to add
+     in the following table. If your model has a different setting, please
+     change the values for ``SherpaMetaData`` accordingly. Otherwise, you
+     will be ``SAD``.
+
+       +----------+--------------------------------------------+
+       | key      | value                                      |
+       +==========+============================================+
+       | 0        | 2 (fixed)                                  |
+       +----------+--------------------------------------------+
+       | 1        | ``--decode-chunk-len``                     |
+       +----------+--------------------------------------------+
+       | 2        | ``--num-left-chunks``                      |
+       +----------+--------------------------------------------+
+       | 3        | 7 (if you don't change the code)           |
+       +----------+--------------------------------------------+
+       | -23316   | ``--num-encoder-layers``                   |
+       +----------+--------------------------------------------+
+       | -23317   | ``--encoder-dims``                         |
+       +----------+--------------------------------------------+
+       | -23318   | ``--attention-dims``                       |
+       +----------+--------------------------------------------+
+       | -23319   | ``--zipformer-downsampling-factors``       |
+       +----------+--------------------------------------------+
+       | -23320   | ``--cnn-module-kernels``                   |
+       +----------+--------------------------------------------+
+
+  4. ``Input in0 0 1 in0``. No need to change it.
+
+.. caution::
+
+   When you add a new layer ``SherpaMetaData``, please remember to update the
+   number of layers. In our case, update ``2028`` to ``2029``. Otherwise,
+   you will be SAD later.
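+
+After editing, you can sanity-check the header (a rough check, not part of
+icefall): the layer count declared on the second line should equal the
+number of layer lines that follow it.
+
+.. code-block:: bash
+
+   # Line 1 is the magic number and line 2 is "num_layers num_outputs",
+   # so the remaining lines should match the declared layer count.
+   declared=$(awk 'NR==2 {print $1}' encoder_jit_trace-pnnx.ncnn.param)
+   actual=$(($(wc -l < encoder_jit_trace-pnnx.ncnn.param) - 2))
+   echo "declared: $declared  actual: $actual"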
+
+.. hint::
+
+   After adding the new layer ``SherpaMetaData``, you cannot use this model
+   with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is
+   supported only in `sherpa-ncnn`_.
+
+.. hint::
+
+   `ncnn`_ is very flexible. You can add new layers to it just by text-editing
+   the ``param`` file! You don't need to change the ``bin`` file.
+
+Now you can use this model in `sherpa-ncnn`_.
+Please refer to the following documentation:
+
+  - Linux/macOS/Windows/arm/aarch64: ``_
+  - ``Android``: ``_
+  - ``iOS``: ``_
+  - Python: ``_
+
+We have a list of pre-trained models that have been exported for `sherpa-ncnn`_:
+
+  - ``_
+
+    You can find more usages there.
diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst
index ed0264089..9eb5f85d2 100644
--- a/docs/source/model-export/export-ncnn.rst
+++ b/docs/source/model-export/export-ncnn.rst
@@ -1,15 +1,27 @@
 Export to ncnn
 ==============
 
-We support exporting both
-`LSTM transducer models `_
-and
-`ConvEmformer transducer models `_
-to `ncnn `_.
+We support exporting the following models
+to `ncnn `_:
 
-We also provide ``_
-performing speech recognition using ``ncnn`` with exported models.
-It has been tested on Linux, macOS, Windows, ``Android``, and ``Raspberry Pi``.
+  - `Zipformer transducer models `_
+
+  - `LSTM transducer models `_
+
+  - `ConvEmformer transducer models `_
+
+We also provide `sherpa-ncnn`_
+for performing speech recognition using `ncnn`_ with exported models.
+It has been tested on the following platforms:
+
+  - Linux
+  - macOS
+  - Windows
+  - ``Android``
+  - ``iOS``
+  - ``Raspberry Pi``
+  - `爱芯派 `_ (`MAIX-III AXera-Pi `_).
+  - `RV1126 `_
 
 `sherpa-ncnn`_ is self-contained and can be statically linked to produce
 a binary containing everything needed. Please refer
@@ -18,754 +30,8 @@ to its documentation for details:
 
  - ``_
 
-Export LSTM transducer models
------------------------------
+.. toctree::
 
-Please refer to :ref:`export-lstm-transducer-model-for-ncnn` for details.
-
-
-
-Export ConvEmformer transducer models
--------------------------------------
-
-We use the pre-trained model from the following repository as an example:
-
-  - ``_
-
-We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_.
-
-.. hint::
-
-  We use ``Ubuntu 18.04``, ``torch 1.10``, and ``Python 3.8`` for testing.
-
-.. caution::
-
-  Please use a more recent version of PyTorch. For instance, ``torch 1.8``
-  may ``not`` work.
-
-1. Download the pre-trained model
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. hint::
-
-  You can also refer to ``_ to download the pre-trained model.
-
-  You have to install `git-lfs`_ before you continue.
-
-.. code-block:: bash
-
-  cd egs/librispeech/ASR
-
-  GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
-  cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
-
-  git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
-  git lfs pull --include "data/lang_bpe_500/bpe.model"
-
-  cd ..
-
-.. note::
-
-  We download ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``.
-
-
-In the above code, we download the pre-trained model into the directory
-``egs/librispeech/ASR/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05``.
-
-2. Install ncnn and pnnx
-^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. code-block:: bash
-
-  # We put ncnn into $HOME/open-source/ncnn
-  # You can change it to anywhere you like
-
-  cd $HOME
-  mkdir -p open-source
-  cd open-source
-
-  git clone https://github.com/csukuangfj/ncnn
-  cd ncnn
-  git submodule update --recursive --init
-
-  # Note: We don't use "python setup.py install" or "pip install ." here
-
-  mkdir -p build-wheel
-  cd build-wheel
-
-  cmake \
-    -DCMAKE_BUILD_TYPE=Release \
-    -DNCNN_PYTHON=ON \
-    -DNCNN_BUILD_BENCHMARK=OFF \
-    -DNCNN_BUILD_EXAMPLES=OFF \
-    -DNCNN_BUILD_TOOLS=ON \
-    ..
-
-  make -j4
-
-  cd ..
-
-  # Note: $PWD here is $HOME/open-source/ncnn
-
-  export PYTHONPATH=$PWD/python:$PYTHONPATH
-  export PATH=$PWD/tools/pnnx/build/src:$PATH
-  export PATH=$PWD/build-wheel/tools/quantize:$PATH
-
-  # Now build pnnx
-  cd tools/pnnx
-  mkdir build
-  cd build
-  cmake ..
-  make -j4
-
-  ./src/pnnx
-
-Congratulations! You have successfully installed the following components:
-
-  - ``pnxx``, which is an executable located in
-    ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use
-    it to convert models exported by ``torch.jit.trace()``.
-  - ``ncnn2int8``, which is an executable located in
-    ``$HOME/open-source/ncnn/build-wheel/tools/quantize``. We will use
-    it to quantize our models to ``int8``.
-  - ``ncnn.cpython-38-x86_64-linux-gnu.so``, which is a Python module located
-    in ``$HOME/open-source/ncnn/python/ncnn``.
-
-    .. note::
-
-      I am using ``Python 3.8``, so it
-      is ``ncnn.cpython-38-x86_64-linux-gnu.so``. If you use a different
-      version, say, ``Python 3.9``, the name would be
-      ``ncnn.cpython-39-x86_64-linux-gnu.so``.
-
-      Also, if you are not using Linux, the file name would also be different.
-      But that does not matter. As long as you can compile it, it should work.
-
-We have set up ``PYTHONPATH`` so that you can use ``import ncnn`` in your
-Python code. We have also set up ``PATH`` so that you can use
-``pnnx`` and ``ncnn2int8`` later in your terminal.
-
-.. caution::
-
-  Please don't use ``_.
-  We have made some modifications to the offical `ncnn`_.
-
-  We will synchronize ``_ periodically
-  with the official one.
-
-3. Export the model via torch.jit.trace()
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-First, let us rename our pre-trained model:
-
-.. code-block::
-
-  cd egs/librispeech/ASR
-
-  cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp
-
-  ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt
-
-  cd ../..
-
-Next, we use the following code to export our model:
-
-.. code-block:: bash
-
-  dir=./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/
-
-  ./conv_emformer_transducer_stateless2/export-for-ncnn.py \
-    --exp-dir $dir/exp \
-    --bpe-model $dir/data/lang_bpe_500/bpe.model \
-    --epoch 30 \
-    --avg 1 \
-    --use-averaged-model 0 \
-    \
-    --num-encoder-layers 12 \
-    --chunk-length 32 \
-    --cnn-module-kernel 31 \
-    --left-context-length 32 \
-    --right-context-length 8 \
-    --memory-size 32 \
-    --encoder-dim 512
-
-.. hint::
-
-  We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``.
-  There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``.
-
-  If you have trained a model by yourself and if you have all checkpoints
-  available, please first use ``decode.py`` to tune ``--epoch --avg``
-  and select the best combination with with ``--use-averaged-model 1``.
-
-.. note::
-
-  You will see the following log output:
-
-  .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt
-
-  The log shows the model has ``75490012`` parameters, i.e., ``~75 M``.
-
-  .. code-block::
-
-    ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt
-
-    -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt
-
-  You can see that the file size of the pre-trained model is ``289 MB``, which
-  is roughly ``75490012*4/1024/1024 = 287.97 MB``.
-
-After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``,
-we will get the following files:
-
-.. code-block:: bash
-
-  ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*pnnx*
-
-  -rw-r--r-- 1 kuangfangjun root 1010K Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.pt
-  -rw-r--r-- 1 kuangfangjun root  283M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.pt
-  -rw-r--r-- 1 kuangfangjun root  3.0M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.pt
-
-
-.. _conv-emformer-step-3-export-torchscript-model-via-pnnx:
-
-3. Export torchscript model via pnnx
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. hint::
-
-  Make sure you have set up the ``PATH`` environment variable. Otherwise,
-  it will throw an error saying that ``pnnx`` could not be found.
-
-Now, it's time to export our models to `ncnn`_ via ``pnnx``.
-
code-block:: - - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - pnnx ./encoder_jit_trace-pnnx.pt - pnnx ./decoder_jit_trace-pnnx.pt - pnnx ./joiner_jit_trace-pnnx.pt - -It will generate the following files: - -.. code-block:: bash - - ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*ncnn*{bin,param} - - -rw-r--r-- 1 kuangfangjun root 503K Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 437 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param - -rw-r--r-- 1 kuangfangjun root 142M Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 79K Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param - -rw-r--r-- 1 kuangfangjun root 1.5M Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 488 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param - -There are two types of files: - -- ``param``: It is a text file containing the model architectures. You can - use a text editor to view its content. -- ``bin``: It is a binary file containing the model parameters. - -We compare the file sizes of the models below before and after converting via ``pnnx``: - -.. see https://tableconvert.com/restructuredtext-generator - -+----------------------------------+------------+ -| File name | File size | -+==================================+============+ -| encoder_jit_trace-pnnx.pt | 283 MB | -+----------------------------------+------------+ -| decoder_jit_trace-pnnx.pt | 1010 KB | -+----------------------------------+------------+ -| joiner_jit_trace-pnnx.pt | 3.0 MB | -+----------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin | 142 MB | -+----------------------------------+------------+ -| decoder_jit_trace-pnnx.ncnn.bin | 503 KB | -+----------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB | -+----------------------------------+------------+ - -You can see that the file sizes of the models after conversion are about one half -of the models before conversion: - - - encoder: 283 MB vs 142 MB - - decoder: 1010 KB vs 503 KB - - joiner: 3.0 MB vs 1.5 MB - -The reason is that by default ``pnnx`` converts ``float32`` parameters -to ``float16``. A ``float32`` parameter occupies 4 bytes, while it is 2 bytes -for ``float16``. Thus, it is ``twice smaller`` after conversion. - -.. hint:: - - If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx`` - won't convert ``float32`` to ``float16``. - -4. Test the exported models in icefall -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. note:: - - We assume you have set up the environment variable ``PYTHONPATH`` when - building `ncnn`_. - -Now we have successfully converted our pre-trained model to `ncnn`_ format. -The generated 6 files are what we need. You can use the following code to -test the converted models: - -.. 
code-block:: bash - - ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ - --tokens ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt \ - --encoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin \ - ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav - -.. hint:: - - `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts - only 1 wave file as input. - -The output is given below: - -.. literalinclude:: ./code/test-stremaing-ncnn-decode-conv-emformer-transducer-libri.txt - -Congratulations! You have successfully exported a model from PyTorch to `ncnn`_! - - -.. _conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn: - -5. Modify the exported encoder for sherpa-ncnn -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In order to use the exported models in `sherpa-ncnn`_, we have to modify -``encoder_jit_trace-pnnx.ncnn.param``. - -Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``: - -.. code-block:: - - 7767517 - 1060 1342 - Input in0 0 1 in0 - -**Explanation** of the above three lines: - - 1. ``7767517``, it is a magic number and should not be changed. - 2. ``1060 1342``, the first number ``1060`` specifies the number of layers - in this file, while ``1342`` specifies the number of intermediate outputs - of this file - 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0`` - is the layer name of this layer; ``0`` means this layer has no input; - ``1`` means this layer has one output; ``in0`` is the output name of - this layer. - -We need to add 1 extra line and also increment the number of layers. -The result looks like below: - -.. code-block:: bash - - 7767517 - 1061 1342 - SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 - Input in0 0 1 in0 - -**Explanation** - - 1. ``7767517``, it is still the same - 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``. - We don't need to change ``1342`` since the newly added layer has no inputs or outputs. - 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512`` - This line is newly added. Its explanation is given below: - - - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``. - - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``. - - ``0 0`` means this layer has no inputs or output. Must be ``0 0`` - - ``0=1``, 0 is the key and 1 is the value. MUST be ``0=1`` - - ``1=12``, 1 is the key and 12 is the value of the - parameter ``--num-encoder-layers`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. 
- - ``2=32``, 2 is the key and 32 is the value of the - parameter ``--memory-size`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``3=31``, 3 is the key and 31 is the value of the - parameter ``--cnn-module-kernel`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``4=8``, 4 is the key and 8 is the value of the - parameter ``--left-context-length`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``5=32``, 5 is the key and 32 is the value of the - parameter ``--chunk-length`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``6=8``, 6 is the key and 8 is the value of the - parameter ``--right-context-length`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - ``7=512``, 7 is the key and 512 is the value of the - parameter ``--encoder-dim`` that you provided when running - ``conv_emformer_transducer_stateless2/export-for-ncnn.py``. - - For ease of reference, we list the key-value pairs that you need to add - in the following table. If your model has a different setting, please - change the values for ``SherpaMetaData`` accordingly. Otherwise, you - will be ``SAD``. - - +------+-----------------------------+ - | key | value | - +======+=============================+ - | 0 | 1 (fixed) | - +------+-----------------------------+ - | 1 | ``--num-encoder-layers`` | - +------+-----------------------------+ - | 2 | ``--memory-size`` | - +------+-----------------------------+ - | 3 | ``--cnn-module-kernel`` | - +------+-----------------------------+ - | 4 | ``--left-context-length`` | - +------+-----------------------------+ - | 5 | ``--chunk-length`` | - +------+-----------------------------+ - | 6 | ``--right-context-length`` | - +------+-----------------------------+ - | 7 | ``--encoder-dim`` | - +------+-----------------------------+ - - 4. ``Input in0 0 1 in0``. No need to change it. - -.. caution:: - - When you add a new layer ``SherpaMetaData``, please remember to update the - number of layers. In our case, update ``1060`` to ``1061``. Otherwise, - you will be SAD later. - -.. hint:: - - After adding the new layer ``SherpaMetaData``, you cannot use this model - with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is - supported only in `sherpa-ncnn`_. - -.. hint:: - - `ncnn`_ is very flexible. You can add new layers to it just by text-editing - the ``param`` file! You don't need to change the ``bin`` file. - -Now you can use this model in `sherpa-ncnn`_. -Please refer to the following documentation: - - - Linux/macOS/Windows/arm/aarch64: ``_ - - Android: ``_ - - Python: ``_ - -We have a list of pre-trained models that have been exported for `sherpa-ncnn`_: - - - ``_ - - You can find more usages there. - -6. (Optional) int8 quantization with sherpa-ncnn -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This step is optional. - -In this step, we describe how to quantize our model with ``int8``. - -Change :ref:`conv-emformer-step-3-export-torchscript-model-via-pnnx` to -disable ``fp16`` when using ``pnnx``: - -.. code-block:: - - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - pnnx ./encoder_jit_trace-pnnx.pt fp16=0 - pnnx ./decoder_jit_trace-pnnx.pt - pnnx ./joiner_jit_trace-pnnx.pt fp16=0 - -.. note:: - - We add ``fp16=0`` when exporting the encoder and joiner. 
`ncnn`_ does not - support quantizing the decoder model yet. We will update this documentation - once `ncnn`_ supports it. (Maybe in this year, 2023). - -It will generate the following files - -.. code-block:: bash - - ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*_jit_trace-pnnx.ncnn.{param,bin} - - -rw-r--r-- 1 kuangfangjun root 503K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 437 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param - -rw-r--r-- 1 kuangfangjun root 283M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 79K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param - -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin - -rw-r--r-- 1 kuangfangjun root 488 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param - -Let us compare again the file sizes: - -+----------------------------------------+------------+ -| File name | File size | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.pt | 283 MB | -+----------------------------------------+------------+ -| decoder_jit_trace-pnnx.pt | 1010 KB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.pt | 3.0 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | -+----------------------------------------+------------+ -| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | -+----------------------------------------+------------+ - -You can see that the file sizes are doubled when we disable ``fp16``. - -.. note:: - - You can again use ``streaming-ncnn-decode.py`` to test the exported models. - -Next, follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` -to modify ``encoder_jit_trace-pnnx.ncnn.param``. - -Change - -.. code-block:: bash - - 7767517 - 1060 1342 - Input in0 0 1 in0 - -to - -.. code-block:: bash - - 7767517 - 1061 1342 - SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512 - Input in0 0 1 in0 - -.. caution:: - - Please follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn` - to change the values for ``SherpaMetaData`` if your model uses a different setting. - - -Next, let us compile `sherpa-ncnn`_ since we will quantize our models within -`sherpa-ncnn`_. - -.. code-block:: bash - - # We will download sherpa-ncnn to $HOME/open-source/ - # You can change it to anywhere you like. - cd $HOME - mkdir -p open-source - - cd open-source - git clone https://github.com/k2-fsa/sherpa-ncnn - cd sherpa-ncnn - mkdir build - cd build - cmake .. 
- make -j 4 - - ./bin/generate-int8-scale-table - - export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH - -The output of the above commands are: - -.. code-block:: bash - - (py38) kuangfangjun:build$ generate-int8-scale-table - Please provide 10 arg. Currently given: 1 - Usage: - generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt - - Each line in wave_filenames.txt is a path to some 16k Hz mono wave file. - -We need to create a file ``wave_filenames.txt``, in which we need to put -some calibration wave files. For testing purpose, we put the ``test_wavs`` -from the pre-trained model repository ``_ - -.. code-block:: bash - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - cat < wave_filenames.txt - ../test_wavs/1089-134686-0001.wav - ../test_wavs/1221-135766-0001.wav - ../test_wavs/1221-135766-0002.wav - EOF - -Now we can calculate the scales needed for quantization with the calibration data: - -.. code-block:: bash - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - generate-int8-scale-table \ - ./encoder_jit_trace-pnnx.ncnn.param \ - ./encoder_jit_trace-pnnx.ncnn.bin \ - ./decoder_jit_trace-pnnx.ncnn.param \ - ./decoder_jit_trace-pnnx.ncnn.bin \ - ./joiner_jit_trace-pnnx.ncnn.param \ - ./joiner_jit_trace-pnnx.ncnn.bin \ - ./encoder-scale-table.txt \ - ./joiner-scale-table.txt \ - ./wave_filenames.txt - -The output logs are in the following: - -.. literalinclude:: ./code/generate-int-8-scale-table-for-conv-emformer.txt - -It generates the following two files: - -.. code-block:: bash - - $ ls -lh encoder-scale-table.txt joiner-scale-table.txt - -rw-r--r-- 1 kuangfangjun root 955K Jan 11 17:28 encoder-scale-table.txt - -rw-r--r-- 1 kuangfangjun root 18K Jan 11 17:28 joiner-scale-table.txt - -.. caution:: - - Definitely, you need more calibration data to compute the scale table. - -Finally, let us use the scale table to quantize our models into ``int8``. - -.. code-block:: bash - - ncnn2int8 - - usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table] - -First, we quantize the encoder model: - -.. code-block:: bash - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - ncnn2int8 \ - ./encoder_jit_trace-pnnx.ncnn.param \ - ./encoder_jit_trace-pnnx.ncnn.bin \ - ./encoder_jit_trace-pnnx.ncnn.int8.param \ - ./encoder_jit_trace-pnnx.ncnn.int8.bin \ - ./encoder-scale-table.txt - -Next, we quantize the joiner model: - -.. code-block:: bash - - ncnn2int8 \ - ./joiner_jit_trace-pnnx.ncnn.param \ - ./joiner_jit_trace-pnnx.ncnn.bin \ - ./joiner_jit_trace-pnnx.ncnn.int8.param \ - ./joiner_jit_trace-pnnx.ncnn.int8.bin \ - ./joiner-scale-table.txt - -The above two commands generate the following 4 files: - -.. code-block:: bash - - -rw-r--r-- 1 kuangfangjun root 99M Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.bin - -rw-r--r-- 1 kuangfangjun root 78K Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.param - -rw-r--r-- 1 kuangfangjun root 774K Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.bin - -rw-r--r-- 1 kuangfangjun root 496 Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.param - -Congratulations! You have successfully quantized your model from ``float32`` to ``int8``. - -.. caution:: - - ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs. 
- - You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param`` - and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like. - - For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can - replace the following invocation: - - .. code-block:: - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - sherpa-ncnn \ - ../data/lang_bpe_500/tokens.txt \ - ./encoder_jit_trace-pnnx.ncnn.param \ - ./encoder_jit_trace-pnnx.ncnn.bin \ - ./decoder_jit_trace-pnnx.ncnn.param \ - ./decoder_jit_trace-pnnx.ncnn.bin \ - ./joiner_jit_trace-pnnx.ncnn.param \ - ./joiner_jit_trace-pnnx.ncnn.bin \ - ../test_wavs/1089-134686-0001.wav - - with - - .. code-block:: - - cd egs/librispeech/ASR - cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/ - - sherpa-ncnn \ - ../data/lang_bpe_500/tokens.txt \ - ./encoder_jit_trace-pnnx.ncnn.int8.param \ - ./encoder_jit_trace-pnnx.ncnn.int8.bin \ - ./decoder_jit_trace-pnnx.ncnn.param \ - ./decoder_jit_trace-pnnx.ncnn.bin \ - ./joiner_jit_trace-pnnx.ncnn.param \ - ./joiner_jit_trace-pnnx.ncnn.bin \ - ../test_wavs/1089-134686-0001.wav - - -The following table compares again the file sizes: - - -+----------------------------------------+------------+ -| File name | File size | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.pt | 283 MB | -+----------------------------------------+------------+ -| decoder_jit_trace-pnnx.pt | 1010 KB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.pt | 3.0 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB | -+----------------------------------------+------------+ -| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB | -+----------------------------------------+------------+ -| encoder_jit_trace-pnnx.ncnn.int8.bin | 99 MB | -+----------------------------------------+------------+ -| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB | -+----------------------------------------+------------+ - -You can see that the file sizes of the model after ``int8`` quantization -are much smaller. - -.. hint:: - - Currently, only linear layers and convolutional layers are quantized - with ``int8``, so you don't see an exact ``4x`` reduction in file sizes. - -.. note:: - - You need to test the recognition accuracy after ``int8`` quantization. - -You can find the speed comparison at ``_. - - -That's it! Have fun with `sherpa-ncnn`_! + export-ncnn-zipformer + export-ncnn-conv-emformer + export-ncnn-lstm diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index dd4b3437a..aa77204cb 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -1,69 +1,95 @@ Export to ONNX ============== -In this section, we describe how to export models to ONNX. +In this section, we describe how to export models to `ONNX`_. + +In each recipe, there is a file called ``export-onnx.py``, which is used +to export trained models to `ONNX`_. 
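+
+After exporting, you can sanity-check the generated files with a few lines of
+Python. The snippet below is only a sketch: the file name
+``encoder-epoch-99-avg-1.onnx`` is assumed for illustration, and it requires
+the ``onnx`` and ``onnxruntime`` packages.
+
+.. code-block:: python
+
+    import onnx
+    import onnxruntime as ort
+
+    # Verify that the exported graph is well formed.
+    model = onnx.load("encoder-epoch-99-avg-1.onnx")
+    onnx.checker.check_model(model)
+
+    # Print the input names, shapes, and types the model expects.
+    sess = ort.InferenceSession(
+        "encoder-epoch-99-avg-1.onnx", providers=["CPUExecutionProvider"]
+    )
+    for inp in sess.get_inputs():
+        print(inp.name, inp.shape, inp.type)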
+
+There is also a file named ``onnx_pretrained.py``, which shows how to use
+the exported `ONNX`_ model in Python with `onnxruntime`_ to decode sound files.
+
+sherpa-onnx
+-----------
+
+We have a separate repository `sherpa-onnx`_ for deploying your exported models
+on various platforms such as:
+
+ - iOS
+ - Android
+ - Raspberry Pi
+ - Linux/macOS/Windows
+
+
+Please see the documentation of `sherpa-onnx`_ for details:
+
+  ``_
+
+Example
+-------
+
+In the following, we demonstrate how to export a streaming Zipformer pre-trained
+model from
+``_
+to `ONNX`_.
+
+Download the pre-trained model
+------------------------------
 
 .. hint::
 
-   Only non-streaming conformer transducer models are tested.
-
-
-When to use it
--------------
-
-It you want to use an inference framework that supports ONNX
-to run the pretrained model.
-
-
-How to export
-------------
-
-We use
-``_
-as an example in the following.
+  We assume you have installed `git-lfs`_.
 
 .. code-block:: bash
 
-  cd egs/librispeech/ASR
-  epoch=14
-  avg=2
-  ./pruned_transducer_stateless3/export.py \
-    --exp-dir ./pruned_transducer_stateless3/exp \
-    --bpe-model data/lang_bpe_500/bpe.model \
-    --epoch $epoch \
-    --avg $avg \
-    --onnx 1
+  cd egs/librispeech/ASR
 
-It will generate the following files inside ``pruned_transducer_stateless3/exp``:
+  repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+  GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+  repo=$(basename $repo_url)
 
-  - ``encoder.onnx``
-  - ``decoder.onnx``
-  - ``joiner.onnx``
-  - ``joiner_encoder_proj.onnx``
-  - ``joiner_decoder_proj.onnx``
+  pushd $repo
+  git lfs pull --include "data/lang_bpe_500/bpe.model"
+  git lfs pull --include "exp/pretrained.pt"
+  cd exp
+  ln -s pretrained.pt epoch-99.pt
+  popd
 
-You can use ``./pruned_transducer_stateless3/exp/onnx_pretrained.py`` to decode
-waves with the generated files:
+Export the model to ONNX
+------------------------
 
 .. code-block:: bash
 
-  ./pruned_transducer_stateless3/onnx_pretrained.py \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
-    --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
-    --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
-    --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
-    --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
-    /path/to/foo.wav \
-    /path/to/bar.wav \
-    /path/to/baz.wav
+  ./pruned_transducer_stateless7_streaming/export-onnx.py \
+    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --use-averaged-model 0 \
+    --epoch 99 \
+    --avg 1 \
+    --decode-chunk-len 32 \
+    --exp-dir $repo/exp/
+
+.. warning::
 
-How to use the exported model
------------------------------
+  ``export-onnx.py`` from different recipes has different options.
 
-We also provide ``_
-performing speech recognition using `onnxruntime `_
-with exported models.
-It has been tested on Linux, macOS, and Windows.
+  In the above example, ``--decode-chunk-len`` is specific to the
+  streaming Zipformer. Other models won't have such an option.
+
+It will generate the following 3 files in ``$repo/exp``:
+
+  - ``encoder-epoch-99-avg-1.onnx``
+  - ``decoder-epoch-99-avg-1.onnx``
+  - ``joiner-epoch-99-avg-1.onnx``
+
+Decode sound files with exported ONNX models
+--------------------------------------------
+
+.. 
code-block:: bash + + ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst index 56a420605..aa73bfe33 100644 --- a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst +++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst @@ -299,11 +299,11 @@ to run the training part first. - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end of each epoch. You can pass ``--epoch`` to - ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them. - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved every ``--save-every-n`` batches. You can pass ``--iter`` to - ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them. + ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them. We suggest that you try both types of checkpoints and choose the one that produces the lowest WERs. @@ -311,7 +311,7 @@ to run the training part first. .. code-block:: bash $ cd egs/librispeech/ASR - $ ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py --help + $ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py --help shows the options for decoding. @@ -320,7 +320,7 @@ The following shows the example using ``epoch-*.pt``: .. code-block:: bash for m in greedy_search fast_beam_search modified_beam_search; do - ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 30 \ --avg 13 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ @@ -333,7 +333,7 @@ To test CTC branch, you can use the following command: .. code-block:: bash for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 30 \ --avg 13 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ @@ -367,7 +367,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p .. hint:: - To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``, + To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``, you can run: .. code-block:: bash @@ -376,7 +376,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p ln -s pretrained epoch-9999.pt And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to - ``./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``. + ``./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``. 
+``./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``.
 
 To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``,
 you can run:
@@ -447,7 +448,8 @@ Download pretrained models
 
 If you don't want to train from scratch, you can download the pretrained models
 by visiting the following links:
 
-  - ``_
+  - trained on LibriSpeech 100h: ``_
+  - trained on LibriSpeech 960h: ``_
 
 See ``_
 for the details of the above pretrained models
diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst
index d81156659..e1382e77d 100644
--- a/docs/source/recipes/Streaming-ASR/introduction.rst
+++ b/docs/source/recipes/Streaming-ASR/introduction.rst
@@ -30,8 +30,9 @@ In icefall, we implement the streaming conformer the way just like what `WeNet <
 See :doc:`Pruned transducer statelessX ` for more details.
 
 .. HINT::
-   If you want to adapt a non-streaming conformer model to be streaming, please refer
-   to `this pull request `_.
+   If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer
+   to `this pull request `_. After adding the code needed by streaming training,
+   you have to re-train it with the extra arguments mentioned in the docs above to get a streaming model.
 
 
 Streaming Emformer
diff --git a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst
index ce8ba1453..911e84656 100644
--- a/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst
+++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst
@@ -515,133 +515,6 @@ To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained``:
 
 Please see ``_
 for how to use the exported models in ``sherpa``.
 
-.. _export-lstm-transducer-model-for-ncnn:
-
-Export LSTM transducer models for ncnn
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-We support exporting pretrained LSTM transducer models to
-`ncnn `_ using
-`pnnx `_.
-
-First, let us install a modified version of ``ncnn``:
-
-.. code-block:: bash
-
-  git clone https://github.com/csukuangfj/ncnn
-  cd ncnn
-  git submodule update --recursive --init
-
-  # Note: We don't use "python setup.py install" or "pip install ." here
-
-  mkdir -p build-wheel
-  cd build-wheel
-
-  cmake \
-    -DCMAKE_BUILD_TYPE=Release \
-    -DNCNN_PYTHON=ON \
-    -DNCNN_BUILD_BENCHMARK=OFF \
-    -DNCNN_BUILD_EXAMPLES=OFF \
-    -DNCNN_BUILD_TOOLS=ON \
-    ..
-
-  make -j4
-
-  cd ..
-
-  # Note: $PWD here is /path/to/ncnn
-
-  export PYTHONPATH=$PWD/python:$PYTHONPATH
-  export PATH=$PWD/tools/pnnx/build/src:$PATH
-  export PATH=$PWD/build-wheel/tools/quantize:$PATH
-
-  # now build pnnx
-  cd tools/pnnx
-  mkdir build
-  cd build
-  cmake ..
-  make -j4
-
-  ./src/pnnx
-
-.. note::
-
-  We assume that you have added the path to the binary ``pnnx`` to the
-  environment variable ``PATH``.
-
-  We also assume that you have added ``build/tools/quantize`` to the environment
-  variable ``PATH`` so that you are able to use ``ncnn2int8`` later.
-
-Second, let us export the model using ``torch.jit.trace()`` that is suitable
-for ``pnnx``:
-
-.. 
code-block:: bash - - iter=468000 - avg=16 - - ./lstm_transducer_stateless2/export.py \ - --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --iter $iter \ - --avg $avg \ - --pnnx 1 - -It will generate 3 files: - - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt`` - -Third, convert torchscript model to ``ncnn`` format: - -.. code-block:: - - pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt - pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt - pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt - -It will generate the following files: - - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param`` - - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin`` - -To use the above generated files, run: - -.. code-block:: bash - - ./lstm_transducer_stateless2/ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ - --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ - /path/to/foo.wav - -.. code-block:: bash - - ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ - --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ - /path/to/foo.wav - -To use the above generated files in C++, please see -``_ - -It is able to generate a static linked executable that can be run on Linux, Windows, -macOS, Raspberry Pi, etc, without external dependencies. 
- Download pretrained models -------------------------- @@ -657,6 +530,3 @@ by visiting the following links: You can find more usages of the pretrained models in ``_ - -Export ConvEmformer transducer models for ncnn -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index d0f118959..2512f233f 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -391,18 +391,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -412,9 +408,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index 75fc6326e..f4a59e552 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -1,7 +1,7 @@ # Introduction -Please refer to +Please refer to for how to run models in this recipe. diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index 20a4f21c7..fb6c7c481 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -388,18 +388,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
results_char = [] for res in results: @@ -413,9 +409,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index bac829ae1..954d9dc7e 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -406,18 +406,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -431,9 +427,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index e019d2329..d23f4f883 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -325,17 +325,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
results_char = [] for res in results: @@ -349,9 +345,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index 41cc1c01c..d164b6890 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -370,18 +370,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -395,9 +391,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 7c06e6e51..0a7d87fe8 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -374,18 +374,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
results_char = [] for res in results: @@ -399,9 +395,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index b5da0959b..9e44b4e34 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -543,18 +543,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -564,9 +560,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 37d766ec8..068e2749a 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -406,18 +406,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -427,9 +423,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index e4a90ef71..6c170c392 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -391,18 +391,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -412,9 +408,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py index 53381c1f4..2741e0eeb 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py @@ -462,18 +462,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -483,9 +479,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py index f47228fbe..9999894d1 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -478,17 +478,13 @@ def save_results( test_set_wers = dict() test_set_cers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - wers_filename = ( - params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt" - ) + wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" with open(wers_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -499,9 +495,7 @@ def save_results( results_char = [] for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - cers_filename = ( - params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" - ) + cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" with open(cers_filename, "w") as f: cer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -512,9 +506,7 @@ def save_results( test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) for key in test_set_wers: diff --git a/egs/csj/ASR/README.md b/egs/csj/ASR/README.md new file mode 100644 index 000000000..95c2ec6ac --- /dev/null +++ b/egs/csj/ASR/README.md @@ -0,0 +1,11 @@ +# Introduction + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +These are the types of architectures currently available. 
+
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
+| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming |
diff --git a/egs/csj/ASR/RESULTS.md b/egs/csj/ASR/RESULTS.md
new file mode 100644
index 000000000..56fdb899f
--- /dev/null
+++ b/egs/csj/ASR/RESULTS.md
@@ -0,0 +1,200 @@
+# Results
+
+## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
+
+### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
+
+See for more details.
+
+You can find a pretrained model, training logs, decoding logs, and decoding results at:
+
+
+Number of model parameters: 75688409, i.e. 75.7M.
+
+#### training on disfluent transcript
+
+The CERs are:
+
+| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
+| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
+| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.4 | 5.02 | --epoch 30 --avg 17 | simulated streaming |
+| fast beam search | 320ms | 5.34 | 4.1 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise |
+| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming |
+| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.7 | 4.98 | --epoch 30 --avg 17 | chunk-wise |
+| modified beam search | 320ms | 5.2 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming |
+| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise |
+| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.6 | --epoch 30 --avg 17 | simulated streaming |
+| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise |
+| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming |
+| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise |
+| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming |
+| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise |
+
+Note: `simulated streaming` indicates feeding the full utterance during decoding using `decode.py`,
+while `chunk-wise` indicates feeding a certain number of frames at a time using `streaming_decode.py`.
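+
+The difference between the two modes can be sketched with a toy example. This
+is an illustration only, not the actual icefall implementation:
+
+```python
+import numpy as np
+
+# A fake utterance of 1000 frames of 80-dim fbank features.
+features = np.random.randn(1000, 80).astype(np.float32)
+
+# Simulated streaming (decode.py): the full utterance is given to the model
+# at once; causal masking inside the encoder limits the visible right context.
+full_input = features
+
+# Chunk-wise streaming (streaming_decode.py): the model receives a fixed
+# number of frames at a time, as in a live application.
+chunk_len = 32  # corresponds to --decode-chunk-len 32
+chunks = [features[i : i + chunk_len] for i in range(0, len(features), chunk_len)]
+
+print(full_input.shape, len(chunks), chunks[0].shape)
+```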
+ +The training command was: +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --max-duration 375 \ + --transcript-mode disfluent \ + --lang data/lang_char \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --pad-feature 30 \ + --musan-dir /mnt/host/corpus/musan/musan/fbank +``` + +The simulated streaming decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --epoch 30 \ + --avg 17 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode disfluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --pad-feature 30 \ + --gpu 0 + done +done +``` + +The streaming chunk-wise decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/streaming_decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \ + --epoch 30 \ + --avg 17 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode disfluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --gpu 2 \ + --num-decode-streams 40 + done +done +``` + +#### training on fluent transcript + +The CERs are: + +| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode | +| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- | +| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming | +| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise | +| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming | +| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise | +| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming | +| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise | +| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming | +| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise | +| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming | +| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise | +| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming | +| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise | + +Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`, +while `chunk-size` indicates 
feeding certain number of frames at each time using `streaming_decode.py`. + +The training command was: +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --max-duration 375 \ + --transcript-mode fluent \ + --lang data/lang_char \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --pad-feature 30 \ + --musan-dir /mnt/host/corpus/musan/musan/fbank +``` + +The simulated streaming decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --epoch 30 \ + --avg 12 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode fluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --pad-feature 30 \ + --gpu 1 + done +done +``` + +The streaming chunk-wise decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/streaming_decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --epoch 30 \ + --avg 12 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode fluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --gpu 3 \ + --num-decode-streams 40 + done +done +``` + +#### Comparing disfluent to fluent + +$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$ + +This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared. 
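+
+As a worked example with made-up counts (the real counts are produced by `disfluent_recogs_to_fluent.py`), suppose the disfluent model made 120 substitutions and 40 deletions on content characters plus 15 insertions, against a fluent reference of 5000 characters:
+
+```python
+# Illustrative numbers only.
+sub_f, ins, del_f = 120, 15, 40  # errors involving content (non-filler) characters
+N_f = 5000                       # fluent reference length in characters
+
+cer = (sub_f + ins + del_f) / N_f
+print(f"{cer:.2%}")  # 3.50%
+```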
+ +| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode | +| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- | +| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming | +| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise | +| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming | +| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise | +| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming | +| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise | +| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming | +| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise | +| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming | +| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise | +| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming | +| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise | +| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | | diff --git a/egs/csj/ASR/local/add_transcript_mode.py b/egs/csj/ASR/local/add_transcript_mode.py new file mode 100644 index 000000000..f6b4b2caf --- /dev/null +++ b/egs/csj/ASR/local/add_transcript_mode.py @@ -0,0 +1,94 @@ +import argparse +import logging +from configparser import ConfigParser +from pathlib import Path +from typing import List + +from lhotse import CutSet, SupervisionSet +from lhotse.recipes.csj import CSJSDBParser + +ARGPARSE_DESCRIPTION = """ +This script adds transcript modes to an existing CutSet or SupervisionSet. +""" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=ARGPARSE_DESCRIPTION, + ) + parser.add_argument( + "-f", + "--fbank-dir", + type=Path, + help="Path to directory where manifests are stored.", + ) + parser.add_argument( + "-c", + "--config", + type=Path, + nargs="+", + help="Path to config file for transcript parsing.", + ) + return parser.parse_args() + + +def get_CSJParsers(config_files: List[Path]) -> List[CSJSDBParser]: + parsers = [] + for config_file in config_files: + config = ConfigParser() + config.optionxform = str + assert config.read(config_file), f"{config_file} could not be found." 
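+        # DECISIONS maps each CSJ tag either to an int code (selecting one of
+        # the predefined handling rules, e.g. 0 to keep, 1 to delete) or to a
+        # literal replacement string; coerce to int where possible and keep
+        # the raw string otherwise.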
+ decisions = {} + for k, v in config["DECISIONS"].items(): + try: + decisions[k] = int(v) + except ValueError: + decisions[k] = v + parsers.append( + (config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions)) + ) + return parsers + + +def main(): + args = get_args() + logging.basicConfig( + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), + level=logging.INFO, + ) + parsers = get_CSJParsers(args.config) + config = ConfigParser() + config.optionxform = str + assert config.read(args.config), args.config + decisions = {} + for k, v in config["DECISIONS"].items(): + try: + decisions[k] = int(v) + except ValueError: + decisions[k] = v + + logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode.") + + manifests = args.fbank_dir.glob("csj_cuts_*.jsonl.gz") + assert manifests, f"No cuts to be found in {args.fbank_dir}" + + for manifest in manifests: + results = [] + logging.info(f"Adding transcript modes to {manifest.name} now.") + cutset = CutSet.from_file(manifest) + for cut in cutset: + for name, parser in parsers: + cut.supervisions[0].custom[name] = parser.parse( + cut.supervisions[0].custom["raw"] + ) + cut.supervisions[0].text = "" + results.append(cut) + results = CutSet.from_items(results) + res_file = manifest.as_posix() + manifest.replace(manifest.parent / ("bak." + manifest.name)) + results.to_file(res_file) + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py index 667ad427e..ce560025d 100644 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,9 +19,7 @@ import argparse import logging import os -from itertools import islice from pathlib import Path -from random import Random from typing import List, Tuple import torch @@ -35,20 +33,10 @@ from lhotse import ( # See the following for why LilcomChunkyWriter is preferre RecordingSet, SupervisionSet, ) +from lhotse.recipes.csj import concat_csj_supervisions # fmt: on -ARGPARSE_DESCRIPTION = """ -This script follows the espnet method of splitting the remaining core+noncore -utterances into valid and train cutsets at an index which is by default 4000. - -In other words, the core+noncore utterances are shuffled, where 4000 utterances -of the shuffled set go to the `valid` cutset and are not subject to speed -perturbation. The remaining utterances become the `train` cutset and are speed- -perturbed (0.9x, 1.0x, 1.1x). - -""" - # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
# Do this outside of main() in case it needs to take effect @@ -57,66 +45,101 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) RNG_SEED = 42 +# concat_params_train = [ +# {"gap": 1.0, "maxlen": 10.0}, +# {"gap": 1.5, "maxlen": 8.0}, +# {"gap": 1.0, "maxlen": 18.0}, +# ] + +concat_params = {"gap": 1.0, "maxlen": 10.0} def make_cutset_blueprints( manifest_dir: Path, - split: int, ) -> List[Tuple[str, CutSet]]: cut_sets = [] + logging.info("Creating non-train cuts.") + # Create eval datasets - logging.info("Creating eval cuts.") for i in range(1, 4): + sps = sorted( + SupervisionSet.from_file( + manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" + ), + key=lambda x: x.id, + ) + cut_set = CutSet.from_manifests( recordings=RecordingSet.from_file( manifest_dir / f"csj_recordings_eval{i}.jsonl.gz" ), - supervisions=SupervisionSet.from_file( - manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" - ), + supervisions=concat_csj_supervisions(sps, **concat_params), ) cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) cut_sets.append((f"eval{i}", cut_set)) - # Create train and valid cuts - logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.") - recording_set = RecordingSet.from_file( - manifest_dir / "csj_recordings_core.jsonl.gz" - ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") - supervision_set = SupervisionSet.from_file( - manifest_dir / "csj_supervisions_core.jsonl.gz" - ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz") - + # Create excluded dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"), + key=lambda x: x.id, + ) cut_set = CutSet.from_manifests( - recordings=recording_set, - supervisions=supervision_set, + recordings=RecordingSet.from_file( + manifest_dir / "csj_recordings_excluded.jsonl.gz" + ), + supervisions=concat_csj_supervisions(sps, **concat_params), ) cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_set = cut_set.shuffle(Random(RNG_SEED)) + cut_sets.append(("excluded", cut_set)) - logging.info( - "Creating valid and train cuts from core and noncore, split at {split}." 
+ # Create valid dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"), + key=lambda x: x.id, ) - valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) + cut_set = CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "csj_recordings_valid.jsonl.gz" + ), + supervisions=concat_csj_supervisions(sps, **concat_params), + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_sets.append(("valid", cut_set)) - train_set = CutSet.from_cuts(islice(cut_set, split, None)) + logging.info("Creating train cuts.") + + # Create train dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz") + + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"), + key=lambda x: x.id, + ) + + recording = RecordingSet.from_file( + manifest_dir / "csj_recordings_core.jsonl.gz" + ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") + + train_set = CutSet.from_manifests( + recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params) + ).trim_to_supervisions(keep_overlapping=False) train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) - cut_sets.extend([("valid", valid_set), ("train", train_set)]) + cut_sets.append(("train", train_set)) return cut_sets def get_args(): parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - - parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") - parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") - parser.add_argument("--split", type=int, default=4000, help="Split at this index") + parser.add_argument( + "-m", "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "-f", "--fbank-dir", type=Path, help="Path to save fbank features" + ) return parser.parse_args() @@ -138,7 +161,7 @@ def main(): ) return else: - cut_sets = make_cutset_blueprints(args.manifest_dir, args.split) + cut_sets = make_cutset_blueprints(args.manifest_dir) for part, cut_set in cut_sets: logging.info(f"Processing {part}") cut_set = cut_set.compute_and_store_features( @@ -147,7 +170,7 @@ def main(): storage_path=(args.fbank_dir / f"feats_{part}").as_posix(), storage_type=LilcomChunkyWriter, ) - cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz") + cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz") logging.info("All fbank computed for CSJ.") (args.fbank_dir / ".done").touch() diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py index f60e62c85..c942df98e 100644 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -28,9 +28,7 @@ from icefall.utils import get_executor ARGPARSE_DESCRIPTION = """ This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. -The generated fbank features are saved in data/fbank. 
""" # Torch's multithreaded behavior needs to be disabled or @@ -42,8 +40,6 @@ torch.set_num_interop_threads(1) def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): - # src_dir = Path("data/manifests") - # output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) num_mel_bins = 80 @@ -104,8 +100,12 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") - parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") + parser.add_argument( + "-m", "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "-f", "--fbank-dir", type=Path, help="Path to save fbank features" + ) return parser.parse_args() diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini index c987e72c5..4f0a9ec0e 100644 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = disfluent -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 0 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 0 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 0 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 0 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 0 -; # Example: '(X (D2 ノ))' -D2^ = 0 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? 
ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 0 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -
<P>
= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index f7f27f5bc..5d033ed17 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = fluent -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 1 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 1 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 1 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 1 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 1 -; # Example: '(X (D2 ノ))' -D2^ = 1 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? 
すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 0 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -
<P>
= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index cf9038f62..3ada9aa24 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = number -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 1 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 1 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 1 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 1 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 1 -; # Example: '(X (D2 ノ))' -D2^ = 1 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? 
すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) A_num = 1 -A_num^ = 1 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -
<P>
= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index f9801284b..dafd65c9a 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,321 +1,80 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode -; # See https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf +; # From https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf MODE = symbol -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . 
-; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' -F = # -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = # +F = "#", ["F"] ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' -D = @ -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = @ +D = "@", ["D"] ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' -D2 = @ -; # Example: '(X (D2 ノ))' -D2^ = @ +D2 = "@", ["D2"] ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 1 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -
<P>
= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/disfluent_recogs_to_fluent.py b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py new file mode 100644 index 000000000..45c9c7656 --- /dev/null +++ b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py @@ -0,0 +1,202 @@ +import argparse +from pathlib import Path + +import kaldialign +from lhotse import CutSet + +ARGPARSE_DESCRIPTION = """ +This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript, +compares it against a fluent transcript, and saves the results in a separate directory. +This is useful to compare disfluent models with fluent models on the same metric. + +""" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=ARGPARSE_DESCRIPTION, + ) + parser.add_argument( + "--recogs", + type=Path, + required=True, + help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.", + ) + parser.add_argument( + "--cut", + type=Path, + required=True, + help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.", + ) + parser.add_argument( + "--res-dir", type=Path, required=True, help="Path to save results" + ) + return parser.parse_args() + + +def d2f(stats): + """ + Compare the outputs of a disfluent model against a fluent reference. + Indicates a disfluent model's performance only on the content words + + CER^d_f = (sub_f + ins + del_f) / Nf + + """ + return stats["base"] / stats["Nf"] + + +def calc_cer(refs, hyps): + subs = { + "F": 0, + "D": 0, + } + ins = 0 + dels = { + "F": 0, + "D": 0, + } + cors = { + "F": 0, + "D": 0, + } + dis_ref_len = 0 + flu_ref_len = 0 + + for ref, hyp in zip(refs, hyps): + assert ( + ref[0] == hyp[0] + ), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}." 
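+        # Each ref entry is (cut_id, characters, per-character CSJ tags); copy
+        # the tag list because pop() consumes it while walking the alignment.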
+ tag = ref[2].copy() + ref = ref[1] + dis_ref_len += len(ref) + # Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively. + flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)]) + hyp = hyp[1] + ali = kaldialign.align(ref, hyp, "*") + tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali] + for tag, (ref_word, hyp_word) in zip(tags, ali): + if "D" in tag or "F" in tag: + tag = "D" + else: + tag = "F" + + if ref_word == "*": + ins += 1 + elif hyp_word == "*": + dels[tag] += 1 + elif ref_word != hyp_word: + subs[tag] += 1 + else: + cors[tag] += 1 + + return { + "subs": subs, + "ins": ins, + "dels": dels, + "cors": cors, + "dis_ref_len": dis_ref_len, + "flu_ref_len": flu_ref_len, + } + + +def for_each_recogs(recogs_file: Path, refs, out_dir): + hyps = [] + with recogs_file.open() as fin: + for line in fin: + if "ref" in line: + continue + cutid, hyp = line.split(":\thyp=") + hyps.append((cutid, eval(hyp))) + + assert len(refs) == len( + hyps + ), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal." + stats = calc_cer(refs, hyps) + stat_table = ["tag,yes,no"] + + for cer_type in ["subs", "dels", "cors", "ins"]: + ret = f"{cer_type}" + for df in ["D", "F"]: + try: + ret += f",{stats[cer_type][df]}" + except TypeError: + # insertions do not belong to F or D, and is not subscriptable. + ret += f",{stats[cer_type]}," + break + stat_table.append(ret) + stat_table = "\n".join(stat_table) + + stats = { + "subd": stats["subs"]["D"], + "deld": stats["dels"]["D"], + "cord": stats["cors"]["D"], + "Nf": stats["flu_ref_len"], + "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"], + } + + cer = d2f(stats) + results = [ + f"{cer:.2%}", + f"Nf,{stats['Nf']}", + ] + results = "\n".join(results) + + with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout: + fout.write(results) + fout.write("\n\n") + fout.write(stat_table) + + +def main(): + args = get_args() + recogs_file: Path = args.recogs + assert ( + recogs_file.is_file() or recogs_file.is_dir() + ), f"recogs_file cannot be found at {recogs_file}." + + args.res_dir.mkdir(parents=True, exist_ok=True) + + if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"): + assert ( + "csj_cuts" in args.cut.name + ), f"Expected {args.cut} to be a cuts manifest." 
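+        # Build the references as (cut_id, disfluent characters, CSJ tag list),
+        # sorted by cut id so they line up with the hyps read from the recogs
+        # file (calc_cer asserts that the ids match pairwise).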
+ + refs: CutSet = CutSet.from_file(args.cut) + refs = sorted( + [ + ( + e.id, + list(e.supervisions[0].custom["disfluent"]), + e.supervisions[0].custom["disfluent_tag"].split(","), + ) + for e in refs + ], + key=lambda x: x[0], + ) + for_each_recogs(recogs_file, refs, args.res_dir) + + elif recogs_file.is_dir(): + recogs_file_path = recogs_file + for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]: + refs: CutSet = CutSet.from_file(args.cut / f"csj_cuts_{partname}.jsonl.gz") + refs = sorted( + [ + ( + r.id, + list(r.supervisions[0].custom["disfluent"]), + r.supervisions[0].custom["disfluent_tag"].split(","), + ) + for r in refs + ], + key=lambda x: x[0], + ) + for recogs_file in recogs_file_path.glob(f"recogs-{partname}-*.txt"): + for_each_recogs(recogs_file, refs, args.res_dir) + + else: + raise TypeError(f"Unrecognised recogs file provided: {recogs_file}") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c043cf853..924474d33 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -45,8 +45,8 @@ def get_parser(): def main(): args = get_parser() - for path in args.manifest_dir.glob("csj_cuts_*.jsonl.gz"): - + for part in ["eval1", "eval2", "eval3", "valid", "excluded", "train"]: + path = args.manifest_dir / f"csj_cuts_{part}.jsonl.gz" cuts: CutSet = load_manifest(path) print("\n---------------------------------\n") @@ -58,123 +58,271 @@ if __name__ == "__main__": main() """ -## eval1 -Cuts count: 1272 -Total duration (hh:mm:ss): 01:50:07 -Speech duration (hh:mm:ss): 01:50:07 (100.0%) -Duration statistics (seconds): -mean 5.2 -std 3.9 -min 0.2 -25% 1.9 -50% 4.0 -75% 8.1 -99% 14.3 -99.5% 14.7 -99.9% 16.0 -max 16.9 -Recordings available: 1272 -Features available: 1272 -Supervisions available: 1272 +csj_cuts_eval1.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1023 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:55:40 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.8 │ +├───────────────────────────┼──────────┤ +│ std │ 2.7 │ +├───────────────────────────┼──────────┤ +│ min │ 0.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.7 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1023 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1023 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- fluent (in 1272 cuts) -- disfluent (in 1272 cuts) -- number (in 1272 cuts) -- symbol (in 1272 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:55:40 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:55:40 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ 
+╘══════════════════════════════╧══════════╧══════════════════════╛ -## eval2 -Cuts count: 1292 -Total duration (hh:mm:ss): 01:56:50 -Speech duration (hh:mm:ss): 01:56:50 (100.0%) -Duration statistics (seconds): -mean 5.4 -std 3.9 -min 0.1 -25% 2.1 -50% 4.6 -75% 8.6 -99% 14.1 -99.5% 15.2 -99.9% 16.1 -max 16.9 -Recordings available: 1292 -Features available: 1292 -Supervisions available: 1292 -SUPERVISION custom fields: -- fluent (in 1292 cuts) -- number (in 1292 cuts) -- symbol (in 1292 cuts) -- disfluent (in 1292 cuts) +--------------------------------- -## eval3 -Cuts count: 1385 -Total duration (hh:mm:ss): 01:19:21 -Speech duration (hh:mm:ss): 01:19:21 (100.0%) -Duration statistics (seconds): -mean 3.4 -std 3.0 -min 0.2 -25% 1.2 -50% 2.5 -75% 4.6 -99% 12.7 -99.5% 13.7 -99.9% 15.0 -max 15.9 -Recordings available: 1385 -Features available: 1385 -Supervisions available: 1385 +csj_cuts_eval2.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1025 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:02:07 │ +├───────────────────────────┼──────────┤ +│ mean │ 7.1 │ +├───────────────────────────┼──────────┤ +│ std │ 2.5 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 5.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1025 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1025 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- number (in 1385 cuts) -- symbol (in 1385 cuts) -- fluent (in 1385 cuts) -- disfluent (in 1385 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:02:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:02:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ -## valid -Cuts count: 4000 -Total duration (hh:mm:ss): 05:08:09 -Speech duration (hh:mm:ss): 05:08:09 (100.0%) -Duration statistics (seconds): -mean 4.6 -std 3.8 -min 0.1 -25% 1.5 -50% 3.4 -75% 7.0 -99% 13.8 -99.5% 14.8 -99.9% 16.0 -max 17.3 -Recordings available: 4000 -Features available: 4000 -Supervisions available: 4000 -SUPERVISION custom fields: -- fluent (in 4000 cuts) -- symbol (in 4000 cuts) -- disfluent (in 4000 cuts) -- number (in 4000 cuts) +--------------------------------- -## train -Cuts count: 1291134 -Total duration (hh:mm:ss): 1596:37:27 -Speech duration (hh:mm:ss): 1596:37:27 (100.0%) -Duration statistics (seconds): -mean 4.5 -std 3.6 -min 0.0 -25% 1.6 -50% 3.3 -75% 6.4 -99% 14.0 -99.5% 14.8 -99.9% 16.6 -max 27.8 -Recordings available: 1291134 -Features available: 1291134 -Supervisions available: 1291134 +csj_cuts_eval3.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 865 │ +├───────────────────────────┼──────────┤ +│ Total duration 
(hh:mm:ss) │ 01:26:44 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.3 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.3 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.8 │ +├───────────────────────────┼──────────┤ +│ 75% │ 8.7 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 865 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 865 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- disfluent (in 1291134 cuts) -- fluent (in 1291134 cuts) -- symbol (in 1291134 cuts) -- number (in 1291134 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:26:44 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:26:44 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_valid.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 3743 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 06:40:15 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.4 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.1 │ +├───────────────────────────┼──────────┤ +│ max │ 11.8 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 3743 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 3743 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 06:40:15 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 06:40:15 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_excluded.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 980 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 00:56:06 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.1 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 
0.8 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.2 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 980 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 980 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 00:56:06 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 00:56:06 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_train.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤════════════╕ +│ Cuts count: │ 914151 │ +├───────────────────────────┼────────────┤ +│ Total duration (hh:mm:ss) │ 1695:29:43 │ +├───────────────────────────┼────────────┤ +│ mean │ 6.7 │ +├───────────────────────────┼────────────┤ +│ std │ 2.9 │ +├───────────────────────────┼────────────┤ +│ min │ 0.1 │ +├───────────────────────────┼────────────┤ +│ 25% │ 4.6 │ +├───────────────────────────┼────────────┤ +│ 50% │ 7.5 │ +├───────────────────────────┼────────────┤ +│ 75% │ 8.9 │ +├───────────────────────────┼────────────┤ +│ 99% │ 11.0 │ +├───────────────────────────┼────────────┤ +│ 99.5% │ 11.0 │ +├───────────────────────────┼────────────┤ +│ 99.9% │ 11.1 │ +├───────────────────────────┼────────────┤ +│ max │ 18.0 │ +├───────────────────────────┼────────────┤ +│ Recordings available: │ 914151 │ +├───────────────────────────┼────────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼────────────┤ +│ Supervisions available: │ 914151 │ +╘═══════════════════════════╧════════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤════════════╤══════════════════════╕ +│ Total speech duration │ 1695:29:43 │ 100.00% of recording │ +├──────────────────────────────┼────────────┼──────────────────────┤ +│ Total speaking time duration │ 1695:29:43 │ 100.00% of recording │ +├──────────────────────────────┼────────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧════════════╧══════════════════════╛ """ diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index 16107f543..58b197922 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -21,24 +21,14 @@ import logging from pathlib import Path from lhotse import CutSet +from lhotse.recipes.csj import CSJSDBParser ARGPARSE_DESCRIPTION = """ -This script gathers all training transcripts of the specified {trans_mode} type -and produces a token_list that would be output set of the ASR system. +This script gathers all training transcripts, parses them in disfluent mode, and produces a token list that would be the output set of the ASR system. 
-It splits transcripts by whitespace into lists, then, for each word in the -list, if the word does not appear in the list of user-defined multicharacter -strings, it further splits that word into individual characters to be counted -into the output token set. - -It outputs 4 files into the lang directory: -- trans_mode: the name of transcript mode. If trans_mode was not specified, - this will be an empty file. -- userdef_string: a list of user defined strings that should not be split - further into individual characters. By default, it contains "<blk>", "<unk>", - "<sos/eos>" -- words_len: the total number of tokens in the output set. -- words.txt: a list of tokens in the output set. The length matches words_len. +It outputs 2 files into the lang directory: +- tokens.txt: a list of tokens in the output set. +- lang_type: a file that contains the string "char" """ @@ -50,98 +40,52 @@ def get_args(): ) parser.add_argument( - "--train-cut", type=Path, required=True, help="Path to the train cut" - ) - - parser.add_argument( - "--trans-mode", - type=str, - default=None, - help=( - "Name of the transcript mode to use. " - "If lang-dir is not set, this will also name the lang-dir" - ), + "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" ) parser.add_argument( "--lang-dir", type=Path, - default=None, + default=Path("data/lang_char"), help=( "Name of lang dir. " - "If not set, this will default to lang_char_{trans-mode}" + "If not set, this will default to data/lang_char" ), ) - parser.add_argument( - "--userdef-string", - type=Path, - default=None, - help="Multicharacter strings that do not need to be split", - ) - return parser.parse_args() def main(): args = get_args() - logging.basicConfig( format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), level=logging.INFO, ) - if not args.lang_dir: - p = "lang_char" - if args.trans_mode: - p += f"_{args.trans_mode}" - args.lang_dir = Path(p) + sysdef_string = set(["<blk>", "<unk>", "<sos/eos>"]) - if args.userdef_string: - args.userdef_string = set(args.userdef_string.read_text().split()) - else: - args.userdef_string = set() + # Parse in disfluent mode, since the fluent transcript is a subset of the disfluent one + parser = CSJSDBParser() - sysdef_string = ["<blk>", "<unk>", "<sos/eos>"] - args.userdef_string.update(sysdef_string) + token_set = set() + logging.info(f"Creating vocabulary from {args.train_cut}.") + train_cut: CutSet = CutSet.from_file(args.train_cut) + for cut in train_cut: + # Skip speed-perturbed copies of the training cuts + if "_sp" in cut.id: + continue - train_set: CutSet = CutSet.from_file(args.train_cut) - - words = set() - logging.info( - f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode."
- ) for cut in train_set: - try: - text: str = ( - cut.supervisions[0].custom[args.trans_mode] - if args.trans_mode - else cut.supervisions[0].text - ) - except KeyError: - raise KeyError( - f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" - ) - for t in text.split(): - if t in args.userdef_string: - words.add(t) - else: - words.update(c for c in list(t)) - - words -= set(sysdef_string) - words = sorted(words) - words = ["<blk>"] + words + ["<unk>", "<sos/eos>"] + text: str = cut.supervisions[0].custom["raw"] + for w in parser.parse(text, sep=" ").split(" "): + token_set.update(w) + token_set = ["<blk>"] + sorted(token_set - sysdef_string) + ["<unk>", "<sos/eos>"] args.lang_dir.mkdir(parents=True, exist_ok=True) - (args.lang_dir / "words.txt").write_text( - "\n".join(f"{word}\t{i}" for i, word in enumerate(words)) + (args.lang_dir / "tokens.txt").write_text( + "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) ) - (args.lang_dir / "words_len").write_text(f"{len(words)}") - - (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) - - (args.lang_dir / "trans_mode").write_text(args.trans_mode) + (args.lang_dir / "lang_type").write_text("char") logging.info("Done.") diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py new file mode 100644 index 000000000..619820a75 --- /dev/null +++ b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -0,0 +1,462 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
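+ +# This module provides CSJAsrDataModule, a Lhotse-based data module for the +# CSJ recipe, together with AsrVariableTranscriptDataset, a +# K2SpeechRecognitionDataset that swaps the supervision text for an +# alternative transcript mode (e.g. fluent) read from each supervision's +# custom fields at load time.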
+ + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrVariableTranscriptDataset(K2SpeechRecognitionDataset): + def __init__( + self, + *args, + transcript_mode: str = "", + return_cuts: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.transcript_mode = transcript_mode + # Force the parent class to keep the cuts in each batch so that the + # alternative transcript can be read from supervision.custom below; + # whether the caller actually receives the cuts is tracked separately. + self.return_cuts = True + self._return_cuts = return_cuts + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + batch = super().__getitem__(cuts) + + if self.transcript_mode: + batch["supervisions"]["text"] = [ + supervision.custom[self.transcript_mode] + for cut in batch["supervisions"]["cut"] + for supervision in cut.supervisions + ] + + if not self._return_cuts: + del batch["supervisions"]["cut"] + + return batch + + +class CSJAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. CSJ eval1, eval2 + and eval3). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--transcript-mode", + type=str, + default="", + help="Mode of transcript in supervision to use.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--musan-dir", type=Path, help="Path to directory with musan cuts." + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler " + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it " + "with training dataset.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. 
mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = AsrVariableTranscriptDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = AsrVariableTranscriptDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
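+        # _SeedWorkers then assigns worker i the seed (seed + i), so each + # dataloader worker draws different but reproducible randomness.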
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = AsrVariableTranscriptDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + else: + validate = AsrVariableTranscriptDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + + test = AsrVariableTranscriptDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_train.jsonl.gz") + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get valid cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_valid.jsonl.gz") + + @lru_cache() + def excluded_cuts(self) -> CutSet: + logging.info("About to get excluded cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_excluded.jsonl.gz") + + @lru_cache() + def eval1_cuts(self) -> CutSet: + logging.info("About to get eval1 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval1.jsonl.gz") + + @lru_cache() + def eval2_cuts(self) -> CutSet: + logging.info("About to get eval2 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval2.jsonl.gz") + + @lru_cache() + def eval3_cuts(self) -> CutSet: + logging.info("About to get eval3 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval3.jsonl.gz") diff --git a/egs/csj/ASR/local/utils/tokenizer.py b/egs/csj/ASR/local/utils/tokenizer.py new file mode 100644 index 000000000..c9be72be1 --- /dev/null +++ b/egs/csj/ASR/local/utils/tokenizer.py @@ -0,0 +1,253 @@ +import argparse +from pathlib import Path +from typing import Callable, List, Union + +import sentencepiece as spm +from k2 import SymbolTable + + +class Tokenizer: + text2word: Callable[[str], List[str]] + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser): + group = 
parser.add_argument_group(title="Lang related options") + + group.add_argument("--lang", type=Path, help="Path to lang directory.") + + group.add_argument( + "--lang-type", + type=str, + default=None, + help=( + "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exist. " + "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor" + ), + ) + + @staticmethod + def Load(lang_dir: Path, lang_type="", oov="<unk>"): + + if not lang_type: + assert (lang_dir / "lang_type").exists(), "lang_type not specified." + lang_type = (lang_dir / "lang_type").read_text().strip() + + tokenizer = None + + if lang_type == "bpe": + assert ( + lang_dir / "bpe.model" + ).exists(), f"No BPE .model could be found in {lang_dir}." + tokenizer = spm.SentencePieceProcessor() + tokenizer.Load(str(lang_dir / "bpe.model")) + elif lang_type == "char": + tokenizer = CharTokenizer(lang_dir, oov=oov) + else: + raise NotImplementedError(f"{lang_type} not supported at the moment.") + + return tokenizer + + load = Load + + def PieceToId(self, piece: str) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + piece_to_id = PieceToId + + def IdToPiece(self, id: int) -> str: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + id_to_piece = IdToPiece + + def GetPieceSize(self) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + get_piece_size = GetPieceSize + + def __len__(self) -> int: + return self.get_piece_size() + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def EncodeAsIds(self, input: str) -> List[int]: + return self.EncodeAsIdsBatch([input])[0] + + def EncodeAsPieces(self, input: str) -> List[str]: + return self.EncodeAsPiecesBatch([input])[0] + + def Encode( + self, input: Union[str, List[str]], out_type=int + ) -> Union[List, List[List]]: + if not input: + return [] + + if isinstance(input, list): + if out_type is int: + return self.EncodeAsIdsBatch(input) + if out_type is str: + return self.EncodeAsPiecesBatch(input) + + if out_type is int: + return self.EncodeAsIds(input) + if out_type is str: + return self.EncodeAsPieces(input) + + encode = Encode + + def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def DecodeIds(self, input: List[int]) -> str: + return self.DecodeIdsBatch([input])[0] + + def DecodePieces(self, input: List[str]) -> str: + return self.DecodePiecesBatch([input])[0] + + def Decode( + self, + input: Union[int, List[int], List[str], List[List[int]], List[List[str]]], + ) -> Union[List[str], str]: + + if not input: + return "" + + if isinstance(input, int): + return self.id_to_piece(input) + elif isinstance(input, str): + raise TypeError( + "Unlike spm.SentencePieceProcessor, cannot decode from type str." 
+ ) + + if isinstance(input[0], list): + if not input[0] or isinstance(input[0][0], int): + return self.DecodeIdsBatch(input) + + if isinstance(input[0][0], str): + return self.DecodePiecesBatch(input) + + if isinstance(input[0], int): + return self.DecodeIds(input) + if isinstance(input[0], str): + return self.DecodePieces(input) + + raise RuntimeError("Unknown input type") + + decode = Decode + + def SplitBatch(self, input: List[str]) -> List[List[str]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]: + if isinstance(input, list): + return self.SplitBatch(input) + elif isinstance(input, str): + return self.SplitBatch([input])[0] + raise RuntimeError("Unknown input type") + + split = Split + + +class CharTokenizer(Tokenizer): + def __init__(self, lang_dir: Path, oov="<unk>", sep=""): + assert ( + lang_dir / "tokens.txt" + ).exists(), f"tokens.txt could not be found in {lang_dir}." + token_table = SymbolTable.from_file(lang_dir / "tokens.txt") + assert ( + "#0" not in token_table + ), "This tokenizer does not support disambig symbols." + self._id2sym = token_table._id2sym + self._sym2id = token_table._sym2id + self.oov = oov + self.oov_id = self._sym2id[oov] + self.sep = sep + if self.sep: + self.text2word = lambda x: x.split(self.sep) + else: + self.text2word = lambda x: list(x.replace(" ", "")) + + def piece_to_id(self, piece: str) -> int: + try: + return self._sym2id[piece] + except KeyError: + return self.oov_id + + def id_to_piece(self, id: int) -> str: + return self._id2sym[id] + + def get_piece_size(self) -> int: + return len(self._sym2id) + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input] + + def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: + return [ + [i if i in self._sym2id else self.oov for i in self.text2word(text)] + for text in input + ] + + def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: + return [self.sep.join(self.id_to_piece(i) for i in text) for text in input] + + def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: + return [self.sep.join(text) for text in input] + + def SplitBatch(self, input: List[str]) -> List[List[str]]: + return [self.text2word(text) for text in input] + + +def test_CharTokenizer(): + test_single_string = "こんにちは" + test_multiple_string = [ + "今日はいい天気ですよね", + "諏訪湖は綺麗でしょう", + "这在词表外", # intentionally out-of-vocabulary (Chinese) text + "分かち 書き に し た 文章 です", + "", + ] + test_empty_string = "" + sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>") + splitter = sp.split + print(sp.encode(test_single_string, out_type=str)) + print(sp.encode(test_single_string, out_type=int)) + print(sp.encode(test_multiple_string, out_type=str)) + print(sp.encode(test_multiple_string, out_type=int)) + print(sp.encode(test_empty_string, out_type=str)) + print(sp.encode(test_empty_string, out_type=int)) + print(sp.decode(sp.encode(test_single_string, out_type=str))) + print(sp.decode(sp.encode(test_single_string, out_type=int))) + print(sp.decode(sp.encode(test_multiple_string, out_type=str))) + print(sp.decode(sp.encode(test_multiple_string, out_type=int))) + print(sp.decode(sp.encode(test_empty_string, out_type=str))) + print(sp.decode(sp.encode(test_empty_string, out_type=int))) + print(splitter(test_single_string)) + print(splitter(test_multiple_string)) + print(splitter(test_empty_string)) + + +if __name__ == "__main__": + test_CharTokenizer()
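The CharTokenizer above deliberately mirrors the sentencepiece API (Encode/Decode, piece_to_id, __len__ and friends), so char and BPE langs can be swapped behind the single Tokenizer interface. A minimal usage sketch, assuming a data/lang_char directory produced by local/prepare_lang_char.py and that the characters below appear in its tokens.txt:

    from pathlib import Path

    from tokenizer import Tokenizer  # egs/csj/ASR/local/utils/tokenizer.py

    sp = Tokenizer.load(Path("data/lang_char"), lang_type="char", oov="<unk>")

    ids = sp.encode("こんにちは", out_type=int)     # one id per character
    pieces = sp.encode("こんにちは", out_type=str)  # ["こ", "ん", "に", "ち", "は"]

    # CharTokenizer defaults to sep="", so decoding joins the characters
    # back together without spaces.
    assert sp.decode(ids) == "こんにちは"

    print(len(sp))  # vocabulary size, same as sp.get_piece_size()

Characters missing from tokens.txt map to the <unk> id on encode, so the round trip above only holds for in-vocabulary text.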
diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh index c4ce91984..52339bb35 100755 --- a/egs/csj/ASR/prepare.sh +++ b/egs/csj/ASR/prepare.sh @@ -32,7 +32,7 @@ # - speech # # By default, this script produces the original transcript like kaldi and espnet. Optionally, you -# can generate other transcript formats by supplying your own config files. A few examples of these +# can add other transcript formats by supplying your own config files. A few examples of these # config files can be found in local/conf. # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 @@ -44,10 +44,10 @@ nj=8 stage=-1 stop_stage=100 -csj_dir=/mnt/minami_data_server/t2131178/corpus/CSJ -musan_dir=/mnt/minami_data_server/t2131178/corpus/musan/musan -trans_dir=$csj_dir/retranscript -csj_fbank_dir=/mnt/host/csj_data/fbank +csj_dir=/mnt/host/corpus/csj +musan_dir=/mnt/host/corpus/musan/musan +trans_dir=$csj_dir/transcript +csj_fbank_dir=/mnt/host/corpus/csj/fbank musan_fbank_dir=$musan_dir/fbank csj_manifest_dir=data/manifests musan_manifest_dir=$musan_dir/manifests @@ -63,12 +63,8 @@ log() { if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare CSJ manifest" - # If you want to generate more transcript modes, append the path to those config files at c. - # Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini - # NOTE: In case multiple config files are supplied, the second config file and onwards will inherit - # the segment boundaries of the first config file. if [ ! -e $csj_manifest_dir/.csj.done ]; then - lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4 + lhotse prepare csj $csj_dir $csj_manifest_dir -t $trans_dir -j 16 touch $csj_manifest_dir/.csj.done fi fi @@ -88,32 +84,24 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \ --fbank-dir $csj_fbank_dir parts=( - train - valid eval1 eval2 eval3 + valid + excluded + train ) for part in ${parts[@]}; do - python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz + python local/validate_manifest.py --manifest $csj_fbank_dir/csj_cuts_$part.jsonl.gz done touch $csj_fbank_dir/.csj-validated.done fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare CSJ lang" - modes=disfluent - - # If you want prepare the lang directory for other transcript modes, just append - # the names of those modes behind. 
An example is shown as below:- - # modes="$modes fluent symbol number" - - for mode in ${modes[@]}; do - python local/prepare_lang_char.py --trans-mode $mode \ - --train-cut $csj_manifest_dir/csj_cuts_train.jsonl.gz \ - --lang-dir lang_char_$mode - done + log "Stage 4: Prepare CSJ lang_char" + python local/prepare_lang_char.py $csj_fbank_dir/csj_cuts_train.jsonl.gz + python local/add_transcript_mode.py -f $csj_fbank_dir -c local/conf/fluent.ini local/conf/number.ini fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then @@ -128,6 +116,6 @@ fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Show manifest statistics" - python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt - cat $csj_manifest_dir/manifest_statistics.txt + python local/display_manifest_statistics.py --manifest-dir $csj_fbank_dir > $csj_fbank_dir/manifest_statistics.txt + cat $csj_fbank_dir/manifest_statistics.txt fi diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py new file mode 100644 index 000000000..f5235207a --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py @@ -0,0 +1,76 @@ +import logging +from configparser import ConfigParser + +import requests + + +def escape_html(text: str): + """ + Escapes all html characters in text + :param str text: + :rtype: str + """ + return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") + + +class TelegramStreamIO(logging.Handler): + + API_ENDPOINT = "https://api.telegram.org" + MAX_MESSAGE_LEN = 4096 + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s at %(funcName)s " + "(line %(lineno)s):\n\n%(message)s" + ) + + def __init__(self, tg_configfile: str): + super(TelegramStreamIO, self).__init__() + config = ConfigParser() + if not config.read(tg_configfile): + raise FileNotFoundError( + f"{tg_configfile} not found. " "Retry without --telegram-cred flag." + ) + config = config["TELEGRAM"] + token = config["token"] + self.chat_id = config["chat_id"] + self.url = f"{self.API_ENDPOINT}/bot{token}/sendMessage" + + @staticmethod + def setup_logger(params): + if not params.telegram_cred: + return + formatter = logging.Formatter( + f"{params.exp_dir.name} %(asctime)s \n%(message)s" + ) + tg = TelegramStreamIO(params.telegram_cred) + tg.setLevel(logging.WARN) + tg.setFormatter(formatter) + logging.getLogger("").addHandler(tg) + + def emit(self, record: logging.LogRecord): + """ + Emit a record. + Send the record to the Web server as a percent-encoded dictionary + """ + data = { + "chat_id": self.chat_id, + "text": self.format(self.mapLogRecord(record)), + "parse_mode": "HTML", + } + try: + requests.get(self.url, json=data) + # return response.json() + except Exception as e: + logging.error(f"Failed to send telegram message: {repr(e)}") + pass + + def mapLogRecord(self, record): + """ + Default implementation of mapping the log record into a dict + that is sent as the CGI data. Overwrite in your class. + Contributed by Franz Glasner. 
+ """ + + for k, v in record.__dict__.items(): + if isinstance(v, str): + setattr(record, k, escape_html(v)) + return record diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a48591198 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..d7349b0a3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..f5a1d750d --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,846 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --lang data/lang_char \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --lang data/lang_char \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --lang data/lang_char \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --lang data/lang_char \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --lang data/lang_char \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --lang data/lang_char \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --lang data/lang_char \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. 
+ Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir. It should contain at least a word table.", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. 
+ Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=30, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The tokenizer (char or BPE). + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
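+ + For example, with greedy_search the returned dict has a single entry, + {"greedy_search": hyps}, where hyps[i] is the list of words produced by + sp.text2word for the i-th utterance in the batch.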
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.text2word(sp.decode(hyp))) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + 
key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = sp.text2word(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WERs of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + return test_set_wers + + +@torch.no_grad() +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + if not params.res_dir: + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # <blk> and <unk> are defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("<blk>") + params.unk_id = sp.piece_to_id("<unk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, 
--avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + decoding_graph = None + word_table = None + + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
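+    # With return_cuts=True, each batch carries batch["supervisions"]["cut"], + # from which decode_dataset reads the cut ids.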
+ args.return_cuts = True + csj_corpus = CSJAsrDataModule(args) + + for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: + results_dict = decode_dataset( + dl=csj_corpus.test_dataloaders(getattr(csj_corpus, f"{subdir}_cuts")()), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, + test_set_name=subdir, + results_dict=results_dict, + ) + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}" for k, v in tot_err)) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..1ce277aa6 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..cb673b3eb --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 100755 index 000000000..ebdb596a5 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/csj/ASR + +repo_url=https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp_fluent/pretrained.pt" + +cd exp_fluent +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ + --lang $repo/data/lang_char \ + --exp-dir $repo/exp_fluent/ \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --decode-chunk-len 32 \ + --num-left-chunks 4 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp_fluent + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-ja-fluent-2023-02-14 + +Please also have a look at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-ja-fluent-2023-02-14/blob/main/export-for-ncnn-fluent.sh + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. 
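+ + A shape sketch (illustrative, assuming --decode-chunk-len 32): the traced call + sees x of shape (1, 39, 80), where T = 32 + 7 = 39, together with the initial + states returned by encoder_model.get_init_state().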
+ """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + assert params.blank_id == 0, params.blank_id + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + 
logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 100644 index 000000000..2d45ecca3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/csj/ASR + ./pruned_transducer_stateless7_streaming/decode.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --lang data/lang_char + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + # You will find the pre-trained model in icefall-asr-csj-pruned-transducer-stateless7-230208/exp_fluent +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + 
filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptable. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 100644 index 000000000..ab7c8748a --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 + +""" +Usage: +# use -O to skip assertions and avoid some of the TracerWarnings +python -O pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, + params: AttributeDict, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + decode_chunk_len = params.decode_chunk_len # before subsampling + pad_length = 7 + s = f"decode_chunk_len: {decode_chunk_len}" + logging.info(s) + assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( + encoder_model.decode_chunk_size, + decode_chunk_len, + ) + + T = decode_chunk_len + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder_model.get_init_state(device=x.device) + + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. 
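+ Shape sketch (illustrative): the dummy inputs below are encoder_out of shape + (1, encoder_out_dim) and decoder_out of shape (1, decoder_out_dim), with both + dims read off the joiner's projection weights, so the traced joiner expects + unprojected inputs.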
+ + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + 
export_encoder_model_jit_trace(model.encoder, encoder_filename, params) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 100644 index 000000000..d84cf04a3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --lang data/lang_char \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ +""" + +import argparse +import logging +from typing import List, Optional + +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. 
", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = Tokenizer.load(args.lang, args.lang_type) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = args.decode_chunk_len + assert encoder.decode_chunk_size == chunk_length // 2, ( + encoder.decode_chunk_size, + chunk_length, + ) + + # we subsample features with ((x_len - 7) // 2 + 1) // 2 + pad_length = 7 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + + states = encoder.get_init_state(device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32) + encoder_out, out_lens, states = encoder( + x=frames, + x_lens=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..482ebcfef --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py 
b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..16c2bf28d --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..522bbaff9 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 100644 index 000000000..932026868 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. 
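+(Here epoch-xx.pt stands for any checkpoint saved by train.py during training; pass it via --checkpoint exactly as you would pretrained.pt.)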
+ +Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_streaming/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + 
for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..a7ef73bcb --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..566c317ff --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 120000 index 000000000..92c3904af --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..2adf271c1 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..9700dd89e --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,597 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --lang data/lang_char \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decode import save_results +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
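+ # Worked example of that arithmetic (illustrative): an embedding output of + # (23 - 7) // 2 = 8 frames is exactly one frame after the maximum downsampling + # factor of 8, so padding every chunk to at least 23 feature frames keeps each + # encoder stack supplied with at least one frame.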
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
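+ # (A DecodeStream buffers that utterance's features, encoder states and + # partial hypothesis; decode_one_chunk above pushes one chunk from up to + # --num-decode-streams active streams through the encoder as a single batch.)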
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model uses normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1]" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Number of cuts processed so far: {num}.") + + # decode the final chunks of the last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +@torch.no_grad() +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if not params.res_dir: + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # <blk> and <unk> are defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("<blk>") + params.unk_id = sp.piece_to_id("<unk>")
sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + args.return_cuts = True + csj_corpus = CSJAsrDataModule(args) + + for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: + results_dict = decode_dataset( + cuts=getattr(csj_corpus, f"{subdir}_cuts")(), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, test_set_name=subdir, results_dict=results_dict + ) + + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}") for k, v 
+ + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..0a82ccfa4 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/csj/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptable.
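+ # torch.jit.ignore leaves forward() as ordinary Python and excludes it from + # compilation, so torch.jit.script() below only compiles the remaining + # methods and submodules (encoder, decoder, joiner).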
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py new file mode 120000 index 000000000..958c99e85 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py @@ -0,0 +1 @@ +../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..601de2c41 --- /dev/null +++ 
b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1304 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + +try: + from TelegramStreamIO import TelegramStreamIO + + HAS_TELEGRAM = True +except ImportError: + HAS_TELEGRAM = False + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward 
dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") + + parser.add_argument( + "--telegram-cred", + type=Path, + default=None, + help="Telegram credentials to report progress in telegram", + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. 
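+ The padded frames are filled with LOG_EPS by compute_loss() before the features enter the encoder.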
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. 
+ "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
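+ sp: + The tokenizer used to encode the supervision texts into token IDs.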
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
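+ # The schedule implemented below: every 100 batches, double the scale while + # it is below 1.0; every 400 batches, double it while it is below 8.0; warn + # once it drops below 0.01 and abort below 1e-5 (a sign of divergence).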
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: + logging.warning( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + else: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + if ( + HAS_TELEGRAM + and batch_idx % (params.valid_interval * 3) == 0 + and not rank + ): + log_mode = logging.warning + else: + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, master_port=params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + if HAS_TELEGRAM and params.telegram_cred: + TelegramStreamIO.setup_logger(params) + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 0.3 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.info( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + csj_corpus = CSJAsrDataModule(args) + train_cuts = csj_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = csj_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = csj_corpus.valid_cuts() + valid_dl = csj_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..d1913d718 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1305 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
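+# NOTE: train2.py mirrors train.py; the difference visible here is that it +# imports Zipformer from zipformer2 and builds the encoder with is_pnnx=True, +# i.e., a variant intended for PNNX/ncnn export.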
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + +try: + from TelegramStreamIO import TelegramStreamIO + + HAS_TELEGRAM = True +except ImportError: + HAS_TELEGRAM = False + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. 
" + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") + + parser.add_argument( + "--telegram-cred", + type=Path, + default=None, + help="Telegram credentials to report progress in telegram", + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor: The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for the multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 1000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # TODO: We can add an option to switch between Zipformer and Transformer
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(",")))
+
+    encoder = Zipformer(
+        num_features=params.feature_dim,
+        output_downsampling_factor=2,
+        zipformer_downsampling_factors=to_int_tuple(
+            params.zipformer_downsampling_factors
+        ),
+        encoder_dims=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
+        encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+        num_left_chunks=params.num_left_chunks,
+        short_chunk_size=params.short_chunk_size,
+        decode_chunk_size=params.decode_chunk_len // 2,
+        is_pnnx=True,
+    )
+    return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    decoder = Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+    return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    joiner = Joiner(
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=int(params.encoder_dims.split(",")[-1]),
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+    return model
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType]
= None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
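(Aside: the checkpoint-selection rule implemented in `load_checkpoint_if_available` above reduces to a small pure function. A hypothetical helper, for illustration only:)

```python
from pathlib import Path
from typing import Optional

def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int) -> Optional[Path]:
    # --start-batch takes precedence over --start-epoch; a fresh run gets None.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None
```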
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. 
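(Aside on `compute_loss` above: the warm-up weighting of the two loss terms is a pair of linear ramps over the first `warm_step` batches. A minimal sketch, assuming the defaults `warm_step=2000` and `simple_loss_scale=0.5`:)

```python
def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    # The simple loss starts at weight 1.0 and decays linearly to s by
    # warm_step; the pruned loss ramps from 0.1 up to 1.0 over the same span.
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac

# e.g. step 0 -> (1.0, 0.1); step 1000 -> (0.75, 0.55); step 2000 -> (0.5, 1.0)
```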
+ optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: + logging.warning( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + else: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + if ( + HAS_TELEGRAM + and batch_idx % (params.valid_interval * 3) == 0 + and not rank + ): + log_mode = logging.warning + else: + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
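(Aside on the fp16 housekeeping in `train_one_epoch` above: the growth rule can be read as a standalone check. A sketch, noting that `scaler._scale` is a private `torch.cuda.amp.GradScaler` field, exactly as used in the loop:)

```python
import logging

def maybe_grow_grad_scale(scaler, batch_idx: int) -> None:
    # Every 100 batches: double the scale if it fell below 1.0, or below 8.0
    # on every 400th batch; warn when it is small, abort once it collapses.
    cur = scaler._scale.item()  # private GradScaler field, as in the loop above
    if cur < 1.0 or (cur < 8.0 and batch_idx % 400 == 0):
        scaler.update(cur * 2.0)
    if cur < 0.01:
        logging.warning(f"Grad scale is small: {cur}")
    if cur < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur}")
```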
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, master_port=params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    if HAS_TELEGRAM and params.telegram_cred:
+        TelegramStreamIO.setup_logger(params)
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = Tokenizer.load(args.lang, args.lang_type)
+
+    # <blk> is defined in local/prepare_lang_char.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 0.3 seconds and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 0.3 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.info(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + csj_corpus = CSJAsrDataModule(args) + train_cuts = csj_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = csj_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = csj_corpus.valid_cuts() + valid_dl = csj_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py new file mode 120000 index 000000000..12dbda888 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 8595c27bd..ee694a9e0 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -399,9 +399,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = post_processing(results) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -409,9 +407,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
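Stepping back to the cut filtering in `run()` above: the T >= S requirement and the subsampling expression are easy to check in isolation. A small sketch mirroring `remove_short_and_long_utt`:

```python
def frames_after_subsampling(num_frames: int) -> int:
    # The zipformer conv front end reduces T frames to ((T - 7) // 2 + 1) // 2.
    return ((num_frames - 7) // 2 + 1) // 2

def keeps_cut(num_frames: int, num_tokens: int, duration: float) -> bool:
    # Pruned RNN-T requires T >= S: at least one frame per output token.
    if duration < 0.3 or duration > 20.0:
        return False
    return frames_after_subsampling(num_frames) >= num_tokens
```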
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -421,9 +417,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 94cb445a8..9ffd78d5b 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -1,6 +1,6 @@ # Introduction -Please refer to for how to run models in this recipe. +Please refer to for how to run models in this recipe. [./RESULTS.md](./RESULTS.md) contains the latest results. diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index b30cf7c1f..ecb84eb01 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -77,6 +77,18 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` +#### Smaller model + +A smaller model (~20M params) is also available with configuration based on [this comment](https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740). The WERs are: + +| decoding method | chunk size | test-clean | test-other | comment | decoding mode | +|----------------------|------------|------------|------------|---------------------|----------------------| +| greedy search | 320ms | 3.94 | 9.79 | --epoch 30 --avg 9 | simulated streaming | +| modified beam search | 320ms | 3.88 | 9.53 | --epoch 30 --avg 9 | simulated streaming | + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + ### zipformer_mmi (zipformer with mmi loss) @@ -93,13 +105,13 @@ results at: Number of model parameters: 69136519, i.e., 69.14 M -| | test-clean | test-other | comment | -|--------------------------|------------|-------------|---------------------| -| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 | -| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 | -| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 | -| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 | -| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 | +| | test-clean | test-other | comment | +| ---------------------- | ---------- | ---------- | ------------------- | +| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 | +| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 | +| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 | +| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 | +| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 | The training commands are: ```bash @@ -134,6 +146,97 @@ for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; done ``` +### pruned_transducer_stateless7_ctc_bs (zipformer with transducer loss and ctc loss using blank skip) + +See https://github.com/k2-fsa/icefall/pull/730 for more details. 
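The idea of blank skip: the auxiliary CTC head marks frames whose blank posterior is high, and the transducer only visits the remaining frames. A purely illustrative sketch (the names and threshold are hypothetical, not the recipe's actual API; see the PR above for the real implementation):

```python
import torch

def non_blank_mask(
    ctc_log_probs: torch.Tensor, blank_id: int = 0, threshold: float = 0.95
) -> torch.Tensor:
    # ctc_log_probs: (N, T, V) log-posteriors from the CTC head.
    # Keep only frames whose blank probability stays below the threshold.
    blank_prob = ctc_log_probs[:, :, blank_id].exp()  # (N, T)
    return blank_prob < threshold  # True = frame kept for the transducer
```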
+
+[pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)
+
+The tensorboard log can be found at
+
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+Number of model parameters: 76804822, i.e., 76.80 M
+
+Tested on an 8-card V100 cluster, with 4 cards busy and 4 cards idle.
+
+#### greedy_search
+
+| model | test-clean | test-other | decoding time (s) | comment |
+| ------------------------------------------------------------ | ---------- | ---------- | ----------------- | ------------------- |
+| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.28 | 5.53 | 48.939 | --epoch 30 --avg 13 |
+| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.24 | 5.18 | 91.900 | --epoch 30 --avg 8 |
+
+- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip in both training and decoding, while [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) doesn't apply blank skip.
+- Applying blank skip in both training and decoding is **1.88 times** faster than the model without blank skip, with no obvious performance loss.
+
+#### modified_beam_search
+
+| model | test-clean | test-other | decoding time (s) | comment |
+| ------------------------------------------------------------ | ---------- | ---------- | ----------------- | ------------------- |
+| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.26 | 5.44 | 80.446 | --epoch 30 --avg 13 |
+| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.20 | 5.12 | 283.676 | --epoch 30 --avg 8 |
+
+- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip in both training and decoding, while [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) doesn't apply blank skip.
+- Applying blank skip in both training and decoding is **3.53 times** faster than the model without blank skip, with no obvious performance loss.
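+(For reference, the quoted speedups are the decoding-time ratios from the tables above: 91.900 s / 48.939 s ≈ 1.88 for greedy search and 283.676 s / 80.446 s ≈ 3.53 for modified beam search.)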
+ +The training commands for the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc_bs/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --ctc-loss-scale 0.2 \ + --master-port 12535 +``` + +The decoding commands for the transducer branch of the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are: + +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + for epoch in 30; do + for avg in 15; do + ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + +The decoding commands for the transducer branch of the model without blank skip ([pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc)) are: + +```bash +for m in greedy_search modified_beam_search fast_beam_search; do + for epoch in 30; do + for avg in 15; do + ./pruned_transducer_stateless7_ctc/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` ### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss) diff --git a/egs/librispeech/ASR/add_alignments.sh b/egs/librispeech/ASR/add_alignments.sh index 5e4480bf6..6c47d25a2 100755 --- a/egs/librispeech/ASR/add_alignments.sh +++ b/egs/librispeech/ASR/add_alignments.sh @@ -2,11 +2,51 @@ set -eou pipefail -alignments_dir=data/alignment +# align could be in ("mfa", "torchaudio") +# We recommend "torchaudio" +align="torchaudio" + +# It adds alignments to the existing fbank features dir (e.g., data/fbank) +# and save cuts to a new dir (e.g., data/fbank_ali). cuts_in_dir=data/fbank cuts_out_dir=data/fbank_ali -python3 ./local/add_alignment_librispeech.py \ - --alignments-dir $alignments_dir \ - --cuts-in-dir $cuts_in_dir \ - --cuts-out-dir $cuts_out_dir +if [ $align == "mfa" ]; then + # It add alignments from https://github.com/CorentinJ/librispeech-alignments, + # generated using the Montreal Forced Aligner (https://montreal-forced-aligner.readthedocs.io). + alignments_dir=data/alignment + + python3 ./local/add_alignment_librispeech.py \ + --alignments-dir $alignments_dir \ + --cuts-in-dir $cuts_in_dir \ + --cuts-out-dir $cuts_out_dir +elif [ $align == "torchaudio" ]; then + # See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/bin/modes/workflows.py for details. + # + # It use a pretrained ASR model from torchaudio to generate alignments. + # It will attach word-level alignment information (start, end, and score) to the + # supervisions in each cut. 
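+  # A quick way to sanity-check the result afterwards (illustrative Python,
+  # not part of this script):
+  #   from lhotse import load_manifest_lazy
+  #   cuts = load_manifest_lazy("data/fbank_ali/librispeech_cuts_dev-clean.jsonl.gz")
+  #   cut = next(iter(cuts))
+  #   print(cut.supervisions[0].alignment["word"][:3])
+  # Each entry should be a word symbol with start, duration, and score attached.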
+ mkdir -p $cuts_out_dir + + parts=( + train-clean-100 + train-clean-360 + train-other-500 + test-clean + test-other + dev-clean + dev-other + ) + + echo "The alignments will be saved to $cuts_out_dir" + for part in ${parts[@]}; do + echo "Start to align $part" + lhotse workflows align-with-torchaudio --dont-normalize-text \ + $cuts_in_dir/librispeech_cuts_${part}.jsonl.gz \ + $cuts_out_dir/librispeech_cuts_${part}.jsonl.gz + done + echo "Finished" +else + echo "align is expected to be in ('mfa', 'torchaudio'), but got $align" + exit 1 +fi diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 8eca2ae02..e6327bb5e 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -40,10 +40,7 @@ Usage: To evaluate symbol delay, you should: (1) Generate cuts with word-time alignments: -./local/add_alignment_librispeech.py \ - --alignments-dir data/alignment \ - --cuts-in-dir data/fbank \ - --cuts-out-dir data/fbank_ali +./add_alignments.sh (2) Set the argument "--manifest-dir data/fbank_ali" while decoding. For example: ./conformer_ctc3/decode.py \ @@ -58,7 +55,6 @@ For example: --left-context 64 \ --manifest-dir data/fbank_ali Note: It supports calculating symbol delay with following decoding methods: - - ctc-greedy-search - ctc-decoding - 1best """ @@ -96,11 +92,11 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - DecodingResults, + convert_timestamp, get_texts, - get_texts_with_timestamp, make_pad_mask, - parse_hyp_and_timestamp, + parse_bpe_start_end_pairs, + parse_fsa_timestamps_and_texts, setup_logger, store_transcripts_and_timestamps, str2bool, @@ -174,11 +170,12 @@ def get_parser(): default="ctc-decoding", help="""Decoding method. Supported values are: - - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + - (0) ctc-greedy-search. It uses a sentence piece model, + i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece model, i.e., lang_dir/bpe.model, to convert word pieces to words. It needs neither a lexicon nor an n-gram LM. - - (1) ctc-greedy-search. It only use CTC output and a sentence piece - model for decoding. It produces the same results with ctc-decoding. - (2) 1best. Extract the best path from the decoding lattice as the decoding result. - (3) nbest. Extract n paths from the decoding lattice; the path @@ -250,6 +247,14 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. + """, + ) + add_model_arguments(parser) return parser @@ -273,41 +278,95 @@ def get_decoding_params() -> AttributeDict: def ctc_greedy_search( ctc_probs: torch.Tensor, nnet_output_lens: torch.Tensor, -) -> List[List[int]]: + sp: spm.SentencePieceProcessor, + subsampling_factor: int = 4, + frame_shift_ms: float = 10, +) -> Tuple[List[Tuple[float, float]], List[List[str]]]: """Apply CTC greedy search Args: - ctc_probs (torch.Tensor): (batch, max_len, feat_dim) - nnet_output_lens (torch.Tensor): (batch, ) + ctc_probs (torch.Tensor): + (batch, max_len, feat_dim) + nnet_output_lens (torch.Tensor): + (batch, ) + sp: + The BPE model. + subsampling_factor: + The subsampling factor of the model. 
+ frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + Returns: - List[List[int]]: best path result + utt_time_pairs: + A list of pair list. utt_time_pairs[i] is a list of + (start-time, end-time) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. """ topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) topk_index = topk_index.squeeze(2) # (B, maxlen) mask = make_pad_mask(nnet_output_lens) topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen) hyps = [hyp.tolist() for hyp in topk_index] - scores = topk_prob.max(1) - ret_hyps = [] - timestamps = [] - for i in range(len(hyps)): - hyp, time = remove_duplicates_and_blank(hyps[i]) - ret_hyps.append(hyp) - timestamps.append(time) - return ret_hyps, timestamps, scores + + def get_first_tokens(tokens: List[int]) -> List[bool]: + is_first_token = [] + first_tokens = [] + for t in range(len(tokens)): + if tokens[t] != 0 and (t == 0 or tokens[t - 1] != tokens[t]): + is_first_token.append(True) + first_tokens.append(tokens[t]) + else: + is_first_token.append(False) + return first_tokens, is_first_token + + utt_time_pairs = [] + utt_words = [] + for utt in range(len(hyps)): + first_tokens, is_first_token = get_first_tokens(hyps[utt]) + all_tokens = sp.id_to_piece(hyps[utt]) + index_pairs = parse_bpe_start_end_pairs(all_tokens, is_first_token) + words = sp.decode(first_tokens).split() + assert len(index_pairs) == len(words), ( + len(index_pairs), + len(words), + all_tokens, + ) + start = convert_timestamp( + frames=[i[0] for i in index_pairs], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + end = convert_timestamp( + # The duration in frames is (end_frame_index - start_frame_index + 1) + frames=[i[1] + 1 for i in index_pairs], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + utt_time_pairs.append(list(zip(start, end))) + utt_words.append(words) + + return utt_time_pairs, utt_words def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]: # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py new_hyp: List[int] = [] - time: List[int] = [] + time: List[Tuple[int, int]] = [] cur = 0 + start, end = -1, -1 while cur < len(hyp): if hyp[cur] != 0: new_hyp.append(hyp[cur]) - time.append(cur) + start = cur prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: + if start != -1: + end = cur cur += 1 + if start != -1 and end != -1: + time.append((start, end)) + start, end = -1, -1 return new_hyp, time @@ -403,13 +462,9 @@ def decode_one_batch( # nnet_output is (N, T, C) if params.decoding_method == "ctc-greedy-search": - hyps, timestamps, _ = ctc_greedy_search( - nnet_output, - encoder_out_lens, - ) - res = DecodingResults(hyps=hyps, timestamps=timestamps) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, + timestamps, hyps = ctc_greedy_search( + ctc_probs=nnet_output, + nnet_output_lens=encoder_out_lens, sp=bpe_model, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, @@ -421,7 +476,7 @@ def decode_one_batch( ( supervisions["sequence_idx"], supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, + encoder_out_lens.cpu(), ), 1, ).to(torch.int32) @@ -434,75 +489,6 @@ def decode_one_batch( assert bpe_model is not None decoding_graph = H - if params.decoding_method in ["1best", "nbest", "nbest-oracle"]: - hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 
1.0] - - ori_scores = decoding_graph.scores.clone() - - ans = {} - for hlg_scale in hlg_scale_list: - decoding_graph.scores = ori_scores * hlg_scale - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - key_suffix = f"-HLG-scale-{hlg_scale}" - - if params.decoding_method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. - best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}" # noqa - timestamps = [[] for _ in range(len(hyps))] - ans[key + key_suffix] = (hyps, timestamps) - - elif params.decoding_method in ["1best", "nbest"]: - if params.decoding_method == "1best": - best_path = one_best_decoding( - lattice=lattice, - use_double_scores=params.use_double_scores, - ) - key = "no-rescore" - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, - subsampling_factor=params.subsampling_factor, - frame_shift_ms=params.frame_shift_ms, - word_table=word_table, - ) - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - timestamps = [[] for _ in range(len(hyps))] - - ans[key + key_suffix] = (hyps, timestamps) - - return ans - lattice = get_lattice( nnet_output=nnet_output, decoding_graph=decoding_graph, @@ -518,13 +504,8 @@ def decode_one_batch( best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, sp=bpe_model, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, @@ -532,6 +513,50 @@ def decode_one_batch( key = "ctc-decoding" return {key: (hyps, timestamps)} + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
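+        # Unlike the removed in-function sweep over hlg_scale_list above,
+        # HLG.scores is now scaled once at load time in main() via
+        # `HLG.scores *= params.hlg_scale`, so comparing scales means one
+        # decode run per --hlg-scale value.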
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}" # noqa + return {key: (hyps, timestamps)} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = f"no_rescore_hlg_scale_{params.hlg_scale}" + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, + word_table=word_table, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + ) + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}" # noqa + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + return {key: (hyps, timestamps)} + assert params.decoding_method in [ "nbest-rescoring", "whole-lattice-rescoring", @@ -581,7 +606,18 @@ def decode_dataset( sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, + List[ + Tuple[ + str, + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], +]: """Decode dataset. Args: @@ -632,7 +668,7 @@ def decode_dataset( time = [] if s.alignment is not None and "word" in s.alignment: time = [ - aliword.start + (aliword.start, aliword.end) for aliword in s.alignment["word"] if aliword.symbol != "" ] @@ -678,27 +714,35 @@ def save_results( test_set_name: str, results_dict: Dict[ str, - List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + List[ + Tuple[ + List[str], + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], ], ): test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( - f, f"{test_set_name}-{key}", results, enable_log=True + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + with_end_time=True, ) test_set_wers[key] = wer test_set_delays[key] = (mean_delay, var_delay) @@ -706,24 +750,22 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) - test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + # sort according to the mean start symbol delay + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: - print("settings\tsymbol-delay", file=f) + print("settings\t(start, end) symbol-delay (s) (start, end)", file=f) for key, val in test_set_delays: print( - "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + "{}\tmean: {}, variance: {}".format(key, val[0], val[1]), file=f, ) @@ -734,10 +776,12 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: - s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note) note = "" logging.info(s) @@ -828,6 +872,7 @@ def main(): ) assert HLG.requires_grad is False + HLG.scores *= params.hlg_scale if not hasattr(HLG, "lm_scores"): HLG.lm_scores = HLG.scores.clone() diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index ac489af9e..2cd223945 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -382,7 +382,7 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, # parameters for loss "beam_size": 10, - "reduction": "sum", + "reduction": "none", "use_double_scores": True, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 365e8b8a7..7be3299f3 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -432,18 +432,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) 
logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -453,9 +449,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index c93125c80..e5a7c7116 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -750,17 +750,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -770,9 +766,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 78e1f4096..d022d463e 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -432,18 +432,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -453,9 +449,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py index f0c92a9b4..4a844b79f 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py @@ -169,9 +169,11 @@ class ConvolutionModule(nn.Module): channels: int, kernel_size: int, bias: bool = True, + is_pnnx: bool = True, ) -> None: """Construct an ConvolutionModule object.""" super().__init__() + self.is_pnnx = is_pnnx # kernerl_size should be an odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0, kernel_size @@ -383,12 +385,14 @@ class ConvolutionModule(nn.Module): - output right_context of shape (R, B, D). - updated cache tensor of shape (B, D, cache_size). """ - # U, B, D = utterance.size() - # R, _, _ = right_context.size() - U = self.chunk_length - B = 1 - D = self.channels - R = self.right_context_length + if self.is_pnnx is False: + U, B, D = utterance.size() + R, _, _ = right_context.size() + else: + U = self.chunk_length + B = 1 + D = self.channels + R = self.right_context_length # point-wise conv x = torch.cat([utterance, right_context], dim=0) # (U + R, B, D) @@ -448,8 +452,10 @@ class EmformerAttention(nn.Module): dropout: float = 0.0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + is_pnnx: bool = True, ): super().__init__() + self.is_pnnx = is_pnnx if embed_dim % nhead != 0: raise ValueError( @@ -539,14 +545,15 @@ class EmformerAttention(nn.Module): left_context_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Underlying chunk-wise attention implementation.""" - # U, B, _ = utterance.size() - # R = right_context.size(0) - # M = memory.size(0) - - U = self.chunk_length - B = 1 - R = self.right_context_length - M = self.memory_size + if self.is_pnnx is False: + U, B, _ = utterance.size() + R = right_context.size(0) + M = memory.size(0) + else: + U = self.chunk_length + B = 1 + R = self.right_context_length + M = self.memory_size L = self.left_context_length scaling = float(self.head_dim) ** -0.5 @@ -570,21 +577,29 @@ class EmformerAttention(nn.Module): # KV = key.size(0) - reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2) - reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute( - 1, 0, 2 - ) - reshaped_value = value.view(M + R + U + L, self.nhead, self.head_dim).permute( - 1, 0, 2 - ) - - # reshaped_query, reshaped_key, reshaped_value = [ - # tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) - # for tensor in [query, key, value] - # ] # (B * nhead, Q or KV, head_dim) - attention_weights = torch.bmm( - reshaped_query * scaling, reshaped_key.permute(0, 2, 1) - ) # (B * nhead, Q, 
KV) + if self.is_pnnx is True: + reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2) + reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute( + 1, 0, 2 + ) + reshaped_value = value.view( + M + R + U + L, self.nhead, self.head_dim + ).permute(1, 0, 2) + else: + reshaped_query, reshaped_key, reshaped_value = [ + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) + for tensor in [query, key, value] + ] # (B * nhead, Q or KV, head_dim) + if self.is_pnnx is True: + attention_weights = torch.bmm( + reshaped_query * scaling, reshaped_key.permute(0, 2, 1) + ) # (B * nhead, Q, KV) + else: + attention_weights = torch.bmm( + reshaped_query * scaling, reshaped_key.transpose(1, 2) + ) # (B * nhead, Q, KV) # compute attention probabilities if False: @@ -597,10 +612,15 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim) - # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim) - # We have to change InnerProduct in ncnn to ignore the extra dim below - attention = attention.unsqueeze(1) + if self.is_pnnx is True: + attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim) + # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim) + # We have to change InnerProduct in ncnn to ignore the extra dim below + attention = attention.unsqueeze(1) + else: + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -733,15 +753,16 @@ class EmformerAttention(nn.Module): - attention value of left context and utterance, which would be cached for next computation, with shape (L + U, B, D). 
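The is_pnnx branches above exist because pnnx traces the model with fixed shapes, so the export path replaces every data-dependent size() call with constants and hardcodes B = 1. For B = 1 the two reshape routes produce identical attention weights; a self-contained check with hypothetical dimensions (Q and KV stand in for R + U and M + R + U + L):

import torch

nhead, head_dim = 8, 64
embed_dim = nhead * head_dim
Q, KV = 40, 104
B = 1  # the pnnx path assumes batch size 1

query = torch.rand(Q, B, embed_dim)
key = torch.rand(KV, B, embed_dim)

# pnnx-friendly route: static view + permute, batch dim folded away.
q1 = query.view(Q, nhead, head_dim).permute(1, 0, 2)
k1 = key.view(KV, nhead, head_dim).permute(1, 0, 2)
w1 = torch.bmm(q1, k1.permute(0, 2, 1))

# Original route: dynamic view over B * nhead, then transpose.
q2 = query.contiguous().view(-1, B * nhead, head_dim).transpose(0, 1)
k2 = key.contiguous().view(-1, B * nhead, head_dim).transpose(0, 1)
w2 = torch.bmm(q2, k2.transpose(1, 2))

assert torch.equal(w1, w2)  # the two routes agree when B == 1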
""" - # U = utterance.size(0) - # R = right_context.size(0) - # L = left_context_key.size(0) - # M = memory.size(0) - - U = self.chunk_length - R = self.right_context_length - L = self.left_context_length - M = self.memory_size + if self.is_pnnx is False: + U = utterance.size(0) + R = right_context.size(0) + L = left_context_key.size(0) + M = memory.size(0) + else: + U = self.chunk_length + R = self.right_context_length + L = self.left_context_length + M = self.memory_size # query = [right context, utterance] Q = R + U @@ -811,6 +832,7 @@ class EmformerEncoderLayer(nn.Module): memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + is_pnnx: bool = True, ): super().__init__() @@ -824,6 +846,7 @@ class EmformerEncoderLayer(nn.Module): dropout=dropout, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, + is_pnnx=is_pnnx, ) self.summary_op = nn.AvgPool1d( kernel_size=chunk_length, stride=chunk_length, ceil_mode=True @@ -850,6 +873,7 @@ class EmformerEncoderLayer(nn.Module): right_context_length, d_model, cnn_module_kernel, + is_pnnx=is_pnnx, ) self.norm_final = BasicNorm(d_model) @@ -1204,6 +1228,7 @@ class EmformerEncoder(nn.Module): memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + is_pnnx: bool = True, ): super().__init__() @@ -1229,6 +1254,7 @@ class EmformerEncoder(nn.Module): memory_size=memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, + is_pnnx=is_pnnx, ) for layer_idx in range(num_encoder_layers) ] @@ -1561,6 +1587,20 @@ class Emformer(EncoderInterface): self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx) self.is_pnnx = is_pnnx + self.num_encoder_layers = num_encoder_layers + self.memory_size = memory_size + self.d_model = d_model + self.cnn_module_kernel = cnn_module_kernel + self.left_context_length = left_context_length // subsampling_factor + self.right_context_length = right_context_length + self.subsampling_factor = subsampling_factor + + assert subsampling_factor == 4, subsampling_factor + pad_length = right_context_length + 2 * 4 + 3 + + # before subsampling + self.T = self.chunk_length + pad_length + self.encoder = EmformerEncoder( chunk_length=chunk_length // subsampling_factor, d_model=d_model, @@ -1575,6 +1615,7 @@ class Emformer(EncoderInterface): memory_size=memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, + is_pnnx=is_pnnx, ) def _forward( @@ -1691,7 +1732,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, - is_pnnx: bool = False, + is_pnnx: bool = True, ) -> None: """ Args: @@ -1767,7 +1808,7 @@ class Conv2dSubsampling(nn.Module): x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - if torch.jit.is_tracing() and self.is_pnnx: + if torch.jit.is_tracing() and self.is_pnnx is True: x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) x = self.out(x) else: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 64c16141c..8fbb02f14 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 """ +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. 
+ Usage: ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ @@ -44,7 +48,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import setup_logger, str2bool def get_parser(): @@ -96,14 +100,6 @@ def get_parser(): help="Path to the BPE model", ) - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - parser.add_argument( "--context-size", type=int, @@ -217,6 +213,8 @@ def main(): device = torch.device("cpu") + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + logging.info(f"device: {device}") sp = spm.SentencePieceProcessor() @@ -312,6 +310,16 @@ def main(): model.eval() convert_scaled_to_non_scaled(model, inplace=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + logging.info("Using torch.jit.trace()") logging.info("Exporting encoder") @@ -330,5 +338,4 @@ def main(): if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py new file mode 100755 index 000000000..ad0b45bd9 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -0,0 +1,644 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./conv_emformer_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to +use the exported ONNX models. 
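After the export finishes, the three generated files can be structurally sanity-checked with the onnx package before any decoding is attempted. A small sketch, assuming the encoder/decoder/joiner filenames from the usage text above:

import onnx

for part in ("encoder", "decoder", "joiner"):
    model = onnx.load(f"{part}-epoch-99-avg-1.onnx")
    onnx.checker.check_model(model)  # raises if the graph is malformed
    print(part, "ok")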
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model +from emformer import Emformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Emformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Emformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Emformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Emformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - a list of states (each layer has 4 states) + + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - a list of states (each layer has 4 states) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use.
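OnnxJoiner above reduces the joiner to logit = output_linear(tanh(encoder_out + decoder_out)), with both inputs already projected to joiner_dim by the encoder/decoder wrappers. A standalone sketch of that formula with hypothetical dimensions:

import torch
import torch.nn as nn

joiner_dim, vocab_size = 512, 500  # hypothetical dims
output_linear = nn.Linear(joiner_dim, vocab_size)

encoder_out = torch.rand(3, joiner_dim)  # as produced by encoder_proj
decoder_out = torch.rand(3, joiner_dim)  # as produced by decoder_proj
logit = output_linear(torch.tanh(encoder_out + decoder_out))
print(logit.shape)  # torch.Size([3, 500])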
+ """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + memory_size = encoder_model.encoder.memory_size + cnn_module_kernel = encoder_model.encoder.cnn_module_kernel + chunk_length = encoder_model.encoder.chunk_length + right_context_length = encoder_model.encoder.right_context_length + encoder_dim = encoder_model.encoder.d_model + left_context_length = encoder_model.encoder.left_context_length + + T = encoder_model.encoder.T + + logging.info(f"num_encoder_layers={num_encoder_layers}") + logging.info(f"memory_size={memory_size}") + logging.info(f"cnn_module_kernel={cnn_module_kernel}") + logging.info(f"chunk_length={chunk_length}") + logging.info(f"right_context_length={right_context_length}") + logging.info(f"encoder_dim={encoder_dim}") + logging.info(f"left_context_length={left_context_length} (after subsampling)") + logging.info(f"T={T}") + + meta_data = { + "model_type": "conv-emformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(chunk_length), # 32 + "T": str(T), # 32 + "num_encoder_layers": str(num_encoder_layers), + "memory_size": str(memory_size), + "cnn_module_kernel": str(cnn_module_kernel), + "right_context_length": str(right_context_length), + "left_context_length": str(left_context_length), + "encoder_dim": str(encoder_dim), + } + logging.info(f"meta_data: {meta_data}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.init_states() + + # Each layer has 4 states + assert len(states) == num_encoder_layers * 4, (len(states), num_encoder_layers) + # layer 0: + # state0: (memory_size, 1, encoder_dim) + # state1: (left_context_length, 1, encoder_dim) + # state2: (left_context_length, 1, encoder_dim) + # state3: (1, encoder_dim, cnn_module_kernel-1) + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(s, name): + assert len(s) == 4, len(s) + logging.info(f"{name}_0.shape: {s[0].shape}") + input_names.append(f"{name}_0") + inputs[f"{name}_0"] = {1: "N"} + output_names.append(f"new_{name}_0") + + logging.info(f"{name}_1.shape: {s[1].shape}") + input_names.append(f"{name}_1") + inputs[f"{name}_1"] = {1: "N"} + output_names.append(f"new_{name}_1") + + logging.info(f"{name}_2.shape: {s[2].shape}") + input_names.append(f"{name}_2") + inputs[f"{name}_2"] = {1: "N"} + output_names.append(f"new_{name}_2") + + logging.info(f"{name}_3.shape: {s[3].shape}") + input_names.append(f"{name}_3") + inputs[f"{name}_3"] = {0: "N"} + output_names.append(f"new_{name}_3") + + for i in range(num_encoder_layers): + base_name = f"layer{i}" + s = states[i * 4 : (i + 1) * 4] + build_inputs_outputs(s, base_name) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. 
+ opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.is_pnnx = False + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError(
+ f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index 949214aec..b53426c75 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -64,6 +64,7 @@ from pathlib import Path import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, 
get_params, get_transducer_model from icefall.checkpoint import ( @@ -258,6 +259,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py new file mode 100755 index 000000000..db92ac696 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" + +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./conv_emformer_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./conv_emformer_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). 
" + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + model_type = encoder_meta["model_type"] + assert model_type == "conv-emformer", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = int(encoder_meta["num_encoder_layers"]) + memory_size = int(encoder_meta["memory_size"]) + cnn_module_kernel = int(encoder_meta["cnn_module_kernel"]) + right_context_length = int(encoder_meta["right_context_length"]) + left_context_length = int(encoder_meta["left_context_length"]) + encoder_dim = int(encoder_meta["encoder_dim"]) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"memory_size: {memory_size}") + logging.info(f"cnn_module_kernel: {cnn_module_kernel}") + logging.info(f"left_context_length: {left_context_length} (after subsampling)") + logging.info(f"right_context_length: {right_context_length}") + logging.info(f"encoder_dim: {encoder_dim}") + + N = batch_size + + states = [] + for i in range(num_encoder_layers): + s0 = torch.zeros(memory_size, N, encoder_dim) + s1 = torch.zeros(left_context_length, N, encoder_dim) + s2 = torch.zeros(left_context_length, N, encoder_dim) + s3 = torch.zeros(N, encoder_dim, cnn_module_kernel - 1) + states.extend([s0, s1, s2, s3]) + + self.states = states + + self.segment = T + self.offset = decode_chunk_len + self.num_encoder_layers = num_encoder_layers + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_inputs_outputs(states: List[torch.Tensor], name: str): + for i in range(4): + if isinstance(states[i], torch.Tensor): + encoder_input[f"{name}_{i}"] = states[i].numpy() + else: + encoder_input[f"{name}_{i}"] = states[i] + + encoder_output.append(f"new_{name}_{i}") + + for i in range(self.num_encoder_layers): + 
base_name = f"layer{i}" + s = self.states[i * 4 : (i + 1) * 4] + build_inputs_outputs(s, base_name) + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. 
+ """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py index e4104a5bb..74da9e6c8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py @@ -203,11 +203,8 @@ class Model: # (1, 512, 2) -> (512, 2) ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone()) - import pdb - - # pdb.set_trace() ret, ncnn_out0 = ex.extract("out0") - # assert ret == 0, ret + assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() out_states: List[torch.Tensor] = [] diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index b2cb2c96b..f5d894a7b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -750,17 +750,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -770,9 +766,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py index c91f94876..dd0a60736 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py @@ -425,6 +425,7 @@ def get_params() -> AttributeDict: "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate + "is_pnnx": True, "env_info": get_env_info(), } ) @@ -446,6 +447,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: left_context_length=params.left_context_length, right_context_length=params.right_context_length, memory_size=params.memory_size, + is_pnnx=params.is_pnnx, ) return encoder diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 9f8503814..745eaf1e8 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -54,10 +54,20 @@ def get_args(): help="""Path to the bpe.model. If not None, we will remove short and long utterances before extracting features""", ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. 
If None, we will use all""", + ) + return parser.parse_args() -def compute_fbank_librispeech(bpe_model: Optional[str] = None): +def compute_fbank_librispeech( + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, +): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -68,15 +78,19 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): sp = spm.SentencePieceProcessor() sp.load(bpe_model) - dataset_parts = ( - "dev-clean", - "dev-other", - "test-clean", - "test-other", - "train-clean-100", - "train-clean-360", - "train-other-500", - ) + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + prefix = "librispeech" suffix = "jsonl.gz" manifests = read_manifests_if_cached( @@ -131,4 +145,4 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() logging.info(vars(args)) - compute_fbank_librispeech(bpe_model=args.bpe_model) + compute_fbank_librispeech(bpe_model=args.bpe_model, dataset=args.dataset) diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index f620b91ea..de49f5321 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -35,6 +35,7 @@ from pathlib import Path from lhotse import CutSet, load_manifest_lazy from lhotse.cut import Cut +from lhotse.dataset.speech_recognition import validate_for_asr def get_args(): @@ -55,16 +56,22 @@ def validate_one_supervision_per_cut(c: Cut): def validate_supervision_and_cut_time_bounds(c: Cut): + tol = 2e-3 # same tolerance as in 'validate_for_asr()' s = c.supervisions[0] - if s.start < c.start: - raise ValueError( - f"{c.id}: Supervision start time {s.start} is less " - f"than cut start time {c.start}" - ) - if s.end > c.end: + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + if s.start < -tol: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger " + f"{c.id}: Supervision start time {s.start} must not be negative." + ) + if s.start > tol: + raise ValueError( + f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`." 
+ ) + if c.start + s.end > c.end + tol: + raise ValueError( + f"{c.id}: Supervision end time {c.start+s.end} is larger " f"than cut end time {c.end}" ) @@ -83,6 +90,12 @@ def main(): validate_one_supervision_per_cut(c) validate_supervision_and_cut_time_bounds(c) + # Validation from K2 training + # - checks supervision start is 0 + # - checks supervision.duration is not longer than cut.duration + # - there is tolerance 2ms + validate_for_asr(cut_set) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 3ad08f56a..856c9d945 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -566,18 +566,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -587,9 +583,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py new file mode 120000 index 000000000..9f5064deb --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export-onnx.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py new file mode 120000 index 000000000..0b1ea0326 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_check.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py new file mode 120000 index 000000000..099c2882f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/onnx_pretrained.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index 961d8ddfb..f989d9bc0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -742,17 +742,13 @@ def save_results( ): test_set_wers = dict() for 
key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -762,9 +758,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 78be9c01f..6c58a57e1 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -702,18 +702,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -723,9 +719,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py new file mode 100755 index 000000000..08bfcb204 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. 
Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export via torch.jit.trace() + +./lstm_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 + +cd ./lstm_transducer_stateless2/exp +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model.
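Before handing the traced files to pnnx (step 2 above), loading one back with torch.jit.load is a cheap smoke test. A sketch, assuming the encoder file written into --exp-dir:

import torch

# Hypothetical path; export-for-ncnn.py writes this into --exp-dir.
encoder = torch.jit.load("encoder_jit_trace-pnnx.pt", map_location="cpu")
print(encoder.code)  # human-readable form of the traced forward()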
+ """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + params.is_pnnx = True + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: 
+ raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py new file mode 100755 index 000000000..46873ebf9 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -0,0 +1,593 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Optional, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from lstm import RNN +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for RNN and the encoder_proj from the joiner""" + + def __init__(self, encoder: RNN, encoder_proj: nn.Linear): + """ + Args: + encoder: + An RNN encoder. + encoder_proj: + The projection layer for encoder from the joiner. 
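+
+        Note: folding encoder_proj into the exported encoder (and, below,
+        decoder_proj into the exported decoder) lets the exported joiner
+        operate directly on (N, joiner_dim) inputs, so no separate
+        projection models are needed at inference time.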
+ """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of RNN.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - updated states, whose shape is the same as the input states. + """ + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device) + encoder_out, _, next_states = self.encoder(x, x_lens, states) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, next_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + and it has 3 outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + d_model = encoder_model.encoder.d_model + rnn_hidden_size = encoder_model.encoder.rnn_hidden_size + + decode_chunk_len = 4 + T = 9 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.get_init_states() + # state0: (num_encoder_layers, batch_size, d_model) + # state1: (num_encoder_layers, batch_size, rnn_hidden_size) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "state0", "state1"], + output_names=["encoder_out", "new_state0", "new_state1"], + dynamic_axes={ + "x": {0: "N"}, + "state0": {1: "N"}, + "state1": {1: "N"}, + "encoder_out": {0: "N"}, + "new_state0": {1: "N"}, + "new_state1": {1: "N"}, + }, + ) + + meta_data = { + "model_type": "lstm", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": str(num_encoder_layers), + "d_model": str(d_model), + "rnn_hidden_size": str(rnn_hidden_size), + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+
+    The exported joiner model has two inputs:
+
+      - encoder_out: a tensor of shape (N, joiner_dim)
+      - decoder_out: a tensor of shape (N, joiner_dim)
+
+    and produces one output:
+
+      - logit: a tensor of shape (N, vocab_size)
+    """
+    joiner_dim = joiner_model.output_linear.weight.shape[1]
+    logging.info(f"joiner dim: {joiner_dim}")
+
+    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+
+    torch.onnx.export(
+        joiner_model,
+        (projected_encoder_out, projected_decoder_out),
+        joiner_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=[
+            "encoder_out",
+            "decoder_out",
+        ],
+        output_names=["logit"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "decoder_out": {0: "N"},
+            "logit": {0: "N"},
+        },
+    )
+    meta_data = {
+        "joiner_dim": str(joiner_dim),
+    }
+    add_meta_data(filename=joiner_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params, enable_giga=False)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start =
f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 5977cb36d..0adc68112 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -74,29 +74,6 @@ with the following commands: git lfs install git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 # You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp - -(3) Export to ONNX format - -./lstm_transducer_stateless2/export.py \ - --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Please see ./streaming-onnx-decode.py for usage of the generated files - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. 
""" import argparse @@ -192,35 +169,6 @@ def get_parser(): """, ) - parser.add_argument( - "--pnnx", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.trace for later - converting to PNNX. It will generate 3 files: - - encoder_jit_trace-pnnx.pt - - decoder_jit_trace-pnnx.pt - - joiner_jit_trace-pnnx.pt - """, - ) - - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit and --pnnx are ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - parser.add_argument( "--context-size", type=int, @@ -305,209 +253,6 @@ def export_joiner_model_jit_trace( logging.info(f"Saved to {joiner_filename}") -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has 3 inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - states: a tuple containing: - - h0: a tensor of shape (num_layers, N, proj_size) - - c0: a tensor of shape (num_layers, N, hidden_size) - - and it has 3 outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - states: a tuple containing: - - next_h0: a tensor of shape (num_layers, N, proj_size) - - next_c0: a tensor of shape (num_layers, N, hidden_size) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - N = 1 - x = torch.zeros(N, 9, 80, dtype=torch.float32) - x_lens = torch.tensor([9], dtype=torch.int64) - h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) - - warmup = 1.0 - torch.onnx.export( - encoder_model, # use torch.jit.trace() internally - (x, x_lens, (h, c), warmup), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens", "h", "c", "warmup"], - output_names=["encoder_out", "encoder_out_lens", "next_h", "next_c"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "h": {1: "N"}, - "c": {1: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - "next_h": {1: "N"}, - "next_c": {1: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. 
- """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "projected_encoder_out", - "projected_decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "projected_encoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -531,10 +276,6 @@ def main(): logging.info(params) - if params.pnnx: - params.is_pnnx = params.pnnx - logging.info("For PNNX") - logging.info("About to create model") model = get_transducer_model(params, enable_giga=False) @@ -629,44 
+370,7 @@ def main(): model.to("cpu") model.eval() - if params.onnx: - logging.info("Export model to ONNX format") - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - - opset_version = 11 - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - - elif params.pnnx: - convert_scaled_to_non_scaled(model, inplace=True) - logging.info("Using torch.jit.trace()") - encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" - export_encoder_model_jit_trace(model.encoder, encoder_filename) - - decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" - export_decoder_model_jit_trace(model.decoder, decoder_filename) - - joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" - export_joiner_model_jit_trace(model.joiner, joiner_filename) - elif params.jit_trace is True: + if params.jit_trace is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.trace()") encoder_filename = params.exp_dir / "encoder_jit_trace.pt" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3bd1b0a09..3eeaa5397 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -19,7 +19,7 @@ """ Usage: ./lstm_transducer_stateless2/ncnn-decode.py \ - --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ @@ -27,15 +27,19 @@ Usage: --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ ./test_wavs/1089-134686-0001.wav + +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for details. 
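Aside: ncnn-decode.py now maps token IDs to text via tokens.txt instead of
sentencepiece. A minimal stand-in for k2.SymbolTable.from_file(), assuming
each line of tokens.txt has the form "<symbol> <integer id>" (hypothetical
helper names, not part of the patch):

from typing import Dict, List


def load_id_to_symbol(tokens_filename: str) -> Dict[int, str]:
    # One token per line: "<symbol> <integer id>"
    id2sym = {}
    with open(tokens_filename, encoding="utf-8") as f:
        for line in f:
            sym, idx = line.split()
            id2sym[int(idx)] = sym
    return id2sym


def ids_to_text(ids: List[int], id2sym: Dict[int, str]) -> str:
    # BPE pieces mark word boundaries with "▁"; this mirrors the
    # post-processing added in this patch.
    return "".join(id2sym[i] for i in ids).replace("▁", " ").strip()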
""" import argparse import logging from typing import List +import k2 import kaldifeat import ncnn -import sentencepiece as spm import torch import torchaudio @@ -44,9 +48,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--bpe-model-filename", + "--tokens", type=str, - help="Path to bpe.model", + help="Path to tokens.txt", ) parser.add_argument( @@ -240,9 +244,6 @@ def main(): model = Model(args) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model_filename) - sound_file = args.sound_filename sample_rate = 16000 @@ -280,8 +281,16 @@ def main(): encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states) hyp = greedy_search(model, encoder_out) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + logging.info(sound_file) - logging.info(sp.decode(hyp)) + logging.info(text) if __name__ == "__main__": diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py new file mode 100755 index 000000000..c83f38b2a --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script checks that exported ONNX models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" + +cd exp +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +2. Export the model via torch.jit.trace() + +./lstm_transducer_stateless2/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit-trace 1 + +It will generate the following 3 files inside $repo/exp + + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + +3. Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. 
Run this file + +./lstm_transducer_stateless2/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +""" + +import argparse +import logging + +from onnx_pretrained import OnnxModel + +from icefall import is_module_available + +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-encoder-filename", + required=True, + type=str, + help="Path to the torchscript encoder model", + ) + + parser.add_argument( + "--jit-decoder-filename", + required=True, + type=str, + help="Path to the torchscript decoder model", + ) + + parser.add_argument( + "--jit-joiner-filename", + required=True, + type=str, + help="Path to the torchscript joiner model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the ONNX encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the ONNX decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the ONNX joiner model", + ) + + return parser + + +def test_encoder( + torch_encoder_model: torch.jit.ScriptModule, + torch_encoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + N = torch.randint(1, 100, size=(1,)).item() + T = onnx_model.segment + C = 80 + x_lens = torch.tensor([T] * N) + torch_states = torch_encoder_model.get_init_states(N) + + onnx_model.init_encoder_states(N) + + for i in range(5): + logging.info(f"test_encoder: iter {i}") + x = torch.rand(N, T, C) + torch_encoder_out, _, torch_states = torch_encoder_model( + x, x_lens, torch_states + ) + torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) + + onnx_encoder_out = onnx_model.run_encoder(x) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_decoder_model: torch.jit.ScriptModule, + torch_decoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_joiner_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] + decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = 
torch_joiner_model.encoder_proj(encoder_out)
+        projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out)
+
+        torch_joiner_out = torch_joiner_model(encoder_out, decoder_out)
+        onnx_joiner_out = onnx_model.run_joiner(
+            projected_encoder_out, projected_decoder_out
+        )
+
+        assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
+            (torch_joiner_out - onnx_joiner_out).abs().max()
+        )
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    logging.info(vars(args))
+
+    torch_encoder_model = torch.jit.load(args.jit_encoder_filename)
+    torch_decoder_model = torch.jit.load(args.jit_decoder_filename)
+    torch_joiner_model = torch.jit.load(args.jit_joiner_filename)
+
+    onnx_model = OnnxModel(
+        encoder_model_filename=args.onnx_encoder_filename,
+        decoder_model_filename=args.onnx_decoder_filename,
+        joiner_model_filename=args.onnx_joiner_filename,
+    )
+
+    logging.info("Test encoder")
+    # When exporting the model to onnx, we have already put the encoder_proj
+    # inside the encoder.
+    test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model)
+
+    logging.info("Test decoder")
+    # When exporting the model to onnx, we have already put the decoder_proj
+    # inside the decoder.
+    test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model)
+
+    logging.info("Test joiner")
+    test_joiner(torch_joiner_model, onnx_model)
+
+    logging.info("Finished checking ONNX models")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+# See https://github.com/pytorch/pytorch/issues/38342
+# and https://github.com/pytorch/pytorch/issues/33354
+#
+# If we don't do this, the delay increases whenever there is
+# a new request that changes the actual batch size.
+# If you use `py-spy dump --pid <id> --native`, you will
+# see a lot of time is spent in re-compiling the torch script model.
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+torch._C._set_graph_executor_optimize(False)
+if __name__ == "__main__":
+    torch.manual_seed(20230207)
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py
new file mode 100755
index 000000000..fb9e121e5
--- /dev/null
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py
@@ -0,0 +1,428 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+"""
+This script loads ONNX models exported by ./export-onnx.py
+and uses them to decode waves.
+
+We use the pre-trained model from
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
+cd exp
+ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+popd
+
+2.
Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./lstm_transducer_stateless2/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1221-135766-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. +""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + model_type = encoder_meta["model_type"] + assert model_type == "lstm", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = int(encoder_meta["num_encoder_layers"]) + d_model = int(encoder_meta["d_model"]) + rnn_hidden_size = int(encoder_meta["rnn_hidden_size"]) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"d_model: {d_model}") + logging.info(f"rnn_hidden_size: {rnn_hidden_size}") + + N = batch_size + + s0 = torch.zeros(num_encoder_layers, N, d_model) + s1 = torch.zeros(num_encoder_layers, N, rnn_hidden_size) + states = [s0.numpy(), s1.numpy()] + + self.states = states + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = { + "x": x.numpy(), + "state0": self.states[0], + "state1": self.states[1], + } + encoder_output = ["encoder_out", "new_state0", "new_state1"] + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-3)//2-1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + 
)[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. 
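+
+      Note: the decoded results are returned together with the decoder
+      output (as the pair (hyp, decoder_out)) so that the decoder output
+      can be reused for the next chunk; the List[int] annotation above is
+      narrower than the actual return type.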
+ """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index 02ed16a8c..cbbc77928 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -16,13 +16,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for usage +""" import argparse import logging from typing import List, Optional +import k2 import ncnn -import sentencepiece as spm import torch import torchaudio from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature @@ -32,9 +37,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--bpe-model-filename", + "--tokens", type=str, - help="Path to bpe.model", + help="Path to tokens.txt", ) parser.add_argument( @@ -251,9 +256,6 @@ def main(): model = Model(args) - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model_filename) - sound_file = args.sound_filename sample_rate = 16000 @@ -329,10 +331,16 @@ def main(): model, encoder_out.squeeze(0), decoder_out, hyp ) + symbol_table = k2.SymbolTable.from_file(args.tokens) + context_size = 2 + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() logging.info(sound_file) - logging.info(sp.decode(hyp[context_size:])) + logging.info(text) if __name__ == "__main__": diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index b7953e5e3..a2b4f9e1a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -94,10 +94,7 @@ Usage: To evaluate symbol delay, you should: (1) Generate cuts with word-time alignments: -./local/add_alignment_librispeech.py \ - --alignments-dir data/alignment \ - --cuts-in-dir data/fbank \ - --cuts-out-dir data/fbank_ali +./add_alignments.sh (2) Set the argument "--manifest-dir data/fbank_ali" while decoding. For example: ./lstm_transducer_stateless3/decode.py \ @@ -484,7 +481,6 @@ def decode_one_batch( res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( - decoding_method=params.decoding_method, res=res, sp=sp, subsampling_factor=params.subsampling_factor, @@ -615,18 +611,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True @@ -637,9 +629,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -647,8 +637,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py new file mode 120000 index 000000000..d56cff73f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-for-ncnn.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py new file mode 120000 index 000000000..9f5064deb --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export-onnx.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-onnx.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 6e51b85e4..59a835d35 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -121,6 +121,8 @@ class RNN(EncoderInterface): Period of auxiliary layers used for random combiner during training. If set to 0, will not use the random combiner (Default). You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. 
""" def __init__( @@ -135,6 +137,7 @@ class RNN(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, aux_layer_period: int = 0, + is_pnnx: bool = False, ) -> None: super(RNN, self).__init__() @@ -148,7 +151,13 @@ class RNN(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx self.num_encoder_layers = num_encoder_layers self.d_model = d_model @@ -216,7 +225,13 @@ class RNN(EncoderInterface): # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning # # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 3) >> 1) - 1) >> 1 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() @@ -377,7 +392,7 @@ class RNNEncoderLayer(nn.Module): # for cell state assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) src_lstm, new_states = self.lstm(src, states) - src = src + self.dropout(src_lstm) + src = self.dropout(src_lstm) + src # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -523,6 +538,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, + is_pnnx: bool = False, ) -> None: """ Args: @@ -535,6 +551,9 @@ class Conv2dSubsampling(nn.Module): Number of channels in layer1 layer1_channels: Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. """ assert in_channels >= 9 super().__init__() @@ -577,6 +596,10 @@ class Conv2dSubsampling(nn.Module): channel_dim=-1, min_positive=0.45, max_positive=0.55 ) + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
@@ -590,9 +613,15 @@ class Conv2dSubsampling(nn.Module): # On entry, x is (N, T, idim) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-3)//2-1)//2, odim) x = self.out_norm(x) x = self.out_balancer(x) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py new file mode 120000 index 000000000..0b1ea0326 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_check.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py new file mode 120000 index 000000000..099c2882f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/onnx_pretrained.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index 109746ed5..c737e3611 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -742,17 +742,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs.
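The is_pnnx branch added to RNN.forward() above computes the subsampled lengths with torch.floor() instead of bit shifts, apparently because the shift-based form does not survive PNNX tracing. Both forms compute ((x_lens - 3) // 2 - 1) // 2; a minimal standalone sketch of the equivalence (the example lengths are made up, not taken from the diff):

    import torch

    x_lens = torch.tensor([100, 83, 37], dtype=torch.int64)  # hypothetical lengths

    # Shift-based formula used for normal training and decoding.
    lengths_shift = (((x_lens - 3) >> 1) - 1) >> 1

    # Floor-division form used when tracing for PNNX export.
    lengths1 = torch.floor((x_lens - 3) / 2)
    lengths_pnnx = torch.floor((lengths1 - 1) / 2).to(x_lens)

    assert torch.equal(lengths_shift, lengths_pnnx)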
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -762,9 +758,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index f56b4fd83..6ef4c9860 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -102,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dim", type=int, default=512, - help="Encoder output dimesion.", + help="Encoder output dimension.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Decoder output dimension.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="Joiner output dimension.", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Dimension of feed forward.", ) parser.add_argument( @@ -395,14 +416,10 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "dim_feedforward": 2048, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), + "is_pnnx": False, } ) @@ -419,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, aux_layer_period=params.aux_layer_period, + is_pnnx=params.is_pnnx, ) return encoder diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index 40d14bb5a..82fd103ea 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -386,17 +386,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
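The train.py hunk above turns decoder_dim, joiner_dim, and dim_feedforward into command-line options rather than hard-coded entries in get_params(); presumably the usual params.update(vars(args)) step in main() then supplies them to get_encoder_model(). A minimal sketch of that pattern (the override value 320 is made up):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--decoder-dim", type=int, default=512)
    parser.add_argument("--joiner-dim", type=int, default=512)
    parser.add_argument("--dim-feedforward", type=int, default=2048)

    # Entries that remain hard-coded in get_params().
    params = {"feature_dim": 80, "subsampling_factor": 4, "is_pnnx": False}

    args = parser.parse_args(["--decoder-dim", "320"])  # hypothetical override
    params.update(vars(args))  # argparse stores "--decoder-dim" under decoder_dim
    assert params["decoder_dim"] == 320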
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -406,9 +402,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 0e3b7ff74..072d49d9c 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -420,18 +420,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -441,9 +437,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 0444afe40..6dfe11cee 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -585,18 +585,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
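The save_results() hunks that recur throughout this patch all make the same change: the decoding-method key is dropped from the output file names, so every decoding method for a test set writes to a single recogs/errs/wer-summary file per suffix. A sketch of the resulting naming (the res_dir and suffix values are hypothetical):

    from pathlib import Path

    res_dir = Path("pruned_transducer_stateless/exp/greedy_search")  # hypothetical
    test_set_name = "test-clean"
    suffix = "epoch-30-avg-10"  # hypothetical params.suffix

    recog_path = res_dir / f"recogs-{test_set_name}-{suffix}.txt"
    errs_filename = res_dir / f"errs-{test_set_name}-{suffix}.txt"
    errs_info = res_dir / f"wer-summary-{test_set_name}-{suffix}.txt"
    # The method key still appears inside the error-stats file via
    # write_error_stats(f, f"{test_set_name}-{key}", results, ...).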
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -606,9 +602,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index fbc39fb65..f4b01fd06 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -423,9 +423,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -434,9 +432,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -446,9 +442,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7388af389..bd2d6e258 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2377,6 +2377,6 @@ def modified_beam_search_lm_shallow_fusion( return ans else: return DecodingResults( - tokens=ans, + hyps=ans, timestamps=ans_timestamps, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 5f135f219..172c9ab7c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -609,18 +609,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word 
error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -630,9 +626,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 1954f4724..9f88bd029 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -56,8 +56,6 @@ class Joiner(nn.Module): """ if not is_jit_tracing(): assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape == decoder_out.shape if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index bb08246d9..9c4a13606 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -425,9 +425,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -436,9 +434,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
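The joiner.py hunk above relaxes the input checks so that only encoder_out.ndim == decoder_out.ndim is still asserted; this matches the 2-D (N, joiner_dim) tensors that the ONNX wrappers later in this patch feed in with project_input=False. A sketch of that 2-D path (the dims are assumed, following the 512/500 defaults used elsewhere in the recipe):

    import torch

    joiner_dim, vocab_size = 512, 500  # assumed dims
    output_linear = torch.nn.Linear(joiner_dim, vocab_size)

    # Both inputs already projected to joiner_dim, as in OnnxJoiner.forward().
    projected_encoder_out = torch.rand(8, joiner_dim)
    projected_decoder_out = torch.rand(8, joiner_dim)

    logit = output_linear(torch.tanh(projected_encoder_out + projected_decoder_out))
    assert logit.shape == (8, vocab_size)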
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -448,9 +444,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 109a94a69..aa055049e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -869,18 +869,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -890,9 +886,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py new file mode 100755 index 000000000..ca8be307c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -0,0 +1,497 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. 
Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.utils import setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Conformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + 
decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 239bdc12f..f30c9df6a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -52,32 +52,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`, It will generate 3 files: `encoder_jit_trace.pt`, `decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. - -(3) Export to ONNX format - -./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Please see ./onnx_pretrained.py for usage of the generated files - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. - -(4) Export `model.state_dict()` +(3) Export `model.state_dict()` ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ @@ -210,23 +185,6 @@ def get_parser(): - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit is ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
- """, - ) - parser.add_argument( "--context-size", type=int, @@ -370,206 +328,6 @@ def export_joiner_model_jit_trace( logging.info(f"Saved to {joiner_filename}") -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - # encoder_model = torch.jit.script(encoder_model) - # It throws the following error for the above statement - # - # RuntimeError: Exporting the operator __is_ to ONNX opset version - # 11 is not supported. Please feel free to request support or - # submit a pull request on PyTorch GitHub. - # - # I cannot find which statement causes the above error. - # torch.onnx.export() will use torch.jit.trace() internally, which - # works well for the current reworked model - warmup = 1.0 - torch.onnx.export( - encoder_model, - (x, x_lens, warmup), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens", "warmup"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "projected_encoder_out", - "projected_decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "projected_encoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -636,31 +394,7 @@ def main(): model.to("cpu") model.eval() - if params.onnx is True: - convert_scaled_to_non_scaled(model, inplace=True) - opset_version = 11 - logging.info("Exporting to onnx format") - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - elif params.jit is True: + if params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.script()") # We won't use the forward() method of the model in C++, so just ignore diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py 
b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 163d737e3..5ca4173c1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -19,21 +19,70 @@ """ This script checks that exported onnx models produce the same output with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model via torchscript (torch.jit.script()) + +./pruned_transducer_stateless3/export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --jit 1 + +It will generate the following file in $repo/exp: + - cpu_jit.pt + +3. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +4. Run this file + +./pruned_transducer_stateless3/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx """ import argparse import logging from icefall import is_module_available +from onnx_pretrained import OnnxModel -if not is_module_available("onnxruntime"): - raise ValueError("Please 'pip install onnxruntime' first.") - -import onnxruntime as ort import torch -ort.set_default_logger_severity(3) - def get_parser(): parser = argparse.ArgumentParser( @@ -68,163 +117,81 @@ def get_parser(): help="Path to the onnx joiner model", ) - parser.add_argument( - "--onnx-joiner-encoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner encoder projection model", - ) - - parser.add_argument( - "--onnx-joiner-decoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner decoder projection model", - ) - return parser def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - inputs = encoder_session.get_inputs() - outputs = encoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] + C = 80 + for i in range(3): + N = torch.randint(low=1, high=20, size=(1,)).item() + T = torch.randint(low=30, high=50, size=(1,)).item() + logging.info(f"test_encoder: iter {i}, N={N}, T={T}") - assert inputs[0].shape == ["N", "T", 80] - assert inputs[1].shape == ["N"] + x = torch.rand(N, T, C) + x_lens = torch.randint(low=30, high=T + 1, size=(N,)) + x_lens[0] = T - for N in [1, 5]: - for T in [12, 25]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = 
torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T + torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) + torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) - encoder_inputs = { - input_names[0]: x.numpy(), - input_names[1]: x_lens.numpy(), - } - encoder_out, encoder_out_lens = encoder_session.run( - output_names, - encoder_inputs, - ) + onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max(), - encoder_out.shape, - torch_encoder_out.shape, - ) + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - inputs = decoder_session.get_inputs() - outputs = decoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) - assert inputs[0].shape == ["N", 2] - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - output_names, - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - - torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() ) def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, - joiner_encoder_proj_session: ort.InferenceSession, - joiner_decoder_proj_session: ort.InferenceSession, + torch_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, ): - joiner_inputs = joiner_session.get_inputs() - joiner_outputs = joiner_session.get_outputs() - joiner_input_names = [n.name for n in joiner_inputs] - joiner_output_names = [n.name for n in joiner_outputs] + encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] + decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) - assert joiner_inputs[0].shape == ["N", 512] - assert joiner_inputs[1].shape == ["N", 512] + projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) + projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) - joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() - encoder_proj_input_name = joiner_encoder_proj_inputs[0].name - - assert 
joiner_encoder_proj_inputs[0].shape == ["N", 512] - - joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() - encoder_proj_output_name = joiner_encoder_proj_outputs[0].name - - joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() - decoder_proj_input_name = joiner_decoder_proj_inputs[0].name - - assert joiner_decoder_proj_inputs[0].shape == ["N", 512] - - joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() - decoder_proj_output_name = joiner_decoder_proj_outputs[0].name - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 512) - decoder_out = torch.rand(N, 512) - - projected_encoder_out = torch.rand(N, 512) - projected_decoder_out = torch.rand(N, 512) - - joiner_inputs = { - joiner_input_names[0]: projected_encoder_out.numpy(), - joiner_input_names[1]: projected_decoder_out.numpy(), - } - joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - projected_encoder_out, - projected_decoder_out, - project_input=False, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() + torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out ) - # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} - joiner_encoder_proj_out = joiner_encoder_proj_session.run( - [encoder_proj_output_name], joiner_encoder_proj_inputs - )[0] - joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) - - torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) - assert torch.allclose( - joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) - - # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} - joiner_decoder_proj_out = joiner_decoder_proj_session.run( - [decoder_proj_output_name], joiner_decoder_proj_inputs - )[0] - joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) - - torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) - assert torch.allclose( - joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) @torch.no_grad() @@ -232,48 +199,38 @@ def main(): args = get_parser().parse_args() logging.info(vars(args)) - model = torch.jit.load(args.jit_filename) + torch_model = torch.jit.load(args.jit_filename) - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) logging.info("Test encoder") - encoder_session = ort.InferenceSession( - args.onnx_encoder_filename, - sess_options=options, - ) - test_encoder(model, encoder_session) + test_encoder(torch_model, onnx_model) logging.info("Test decoder") - decoder_session = ort.InferenceSession( - args.onnx_decoder_filename, - sess_options=options, - ) - test_decoder(model, decoder_session) + test_decoder(torch_model, onnx_model) logging.info("Test joiner") - joiner_session = ort.InferenceSession( 
- args.onnx_joiner_filename, - sess_options=options, - ) - joiner_encoder_proj_session = ort.InferenceSession( - args.onnx_joiner_encoder_proj_filename, - sess_options=options, - ) - joiner_decoder_proj_session = ort.InferenceSession( - args.onnx_joiner_decoder_proj_filename, - sess_options=options, - ) - test_joiner( - model, - joiner_session, - joiner_encoder_proj_session, - joiner_decoder_proj_session, - ) + test_joiner(torch_model, onnx_model) logging.info("Finished checking ONNX models") +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid <pid> --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) if __name__ == "__main__": torch.manual_seed(20220727) formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 550cf6aad..5adb6c16a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -18,35 +18,61 @@ This script loads ONNX models and uses them to decode waves. You can use the following command to get the exported models: -./pruned_transducer_stateless3/export.py \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +as an example to show how to use this file. -Usage of this script: +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt" + +cd exp +ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless3/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-9999-avg-1.onnx + - decoder-epoch-9999-avg-1.onnx + - joiner-epoch-9999-avg-1.onnx + +3.
Run this file ./pruned_transducer_stateless3/onnx_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ - --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ - --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav + --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav """ import argparse import logging import math -from typing import List +from typing import List, Tuple +import k2 import kaldifeat import numpy as np import onnxruntime as ort -import sentencepiece as spm import torch import torchaudio from torch.nn.utils.rnn import pad_sequence @@ -79,23 +105,9 @@ def get_parser(): ) parser.add_argument( - "--joiner-encoder-proj-model-filename", + "--tokens", type=str, - required=True, - help="Path to the joiner encoder_proj onnx model. ", - ) - - parser.add_argument( - "--joiner-decoder-proj-model-filename", - type=str, - required=True, - help="Path to the joiner decoder_proj onnx model. ", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -115,16 +127,122 @@ def get_parser(): help="The sample rate of the input sound file", ) - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - return parser +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def run_encoder( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,).
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, its shape is (N, T', joiner_dim) + - encoder_out_lens, its shape is (N,) + """ + out = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]), torch.from_numpy(out[1]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + """ + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -149,36 +267,22 @@ def read_sound_files( def greedy_search( - decoder: ort.InferenceSession, - joiner: ort.InferenceSession, - joiner_encoder_proj: ort.InferenceSession, - joiner_decoder_proj: ort.InferenceSession, - encoder_out: np.ndarray, - encoder_out_lens: np.ndarray, - context_size: int, + model: OnnxModel, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, ) -> List[List[int]]: """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. Args: - decoder: - The decoder model. - joiner: - The joiner model. - joiner_encoder_proj: - The joiner encoder projection model. - joiner_decoder_proj: - The joiner decoder projection model. + model: + The transducer model. encoder_out: - A 3-D tensor of shape (N, T, C) + A 3-D tensor of shape (N, T, joiner_dim) encoder_out_lens: A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. Returns: Return the decoded results for each utterance. 
""" - encoder_out = torch.from_numpy(encoder_out) - encoder_out_lens = torch.from_numpy(encoder_out_lens) - assert encoder_out.ndim == 3 + assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( @@ -188,11 +292,6 @@ def greedy_search( enforce_sorted=False, ) - projected_encoder_out = joiner_encoder_proj.run( - [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, - )[0] - blank_id = 0 # hard-code to 0 batch_size_list = packed_encoder_out.batch_sizes.tolist() @@ -201,50 +300,27 @@ def greedy_search( assert torch.all(encoder_out_lens > 0), encoder_out_lens assert N == batch_size_list[0], (N, batch_size_list) + context_size = model.context_size hyps = [[blank_id] * context_size for _ in range(N)] - decoder_input_nodes = decoder.get_inputs() - decoder_output_nodes = decoder.get_outputs() - - joiner_input_nodes = joiner.get_inputs() - joiner_output_nodes = joiner.get_outputs() - decoder_input = torch.tensor( hyps, dtype=torch.int64, ) # (N, context_size) - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - - projected_decoder_out = torch.from_numpy(projected_decoder_out) + decoder_out = model.run_decoder(decoder_input) offset = 0 for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = projected_encoder_out[start:end] - # current_encoder_out's shape: (batch_size, encoder_out_dim) + current_encoder_out = packed_encoder_out.data[start:end] + # current_encoder_out's shape: (batch_size, joiner_dim) offset = end - projected_decoder_out = projected_decoder_out[:batch_size] + decoder_out = decoder_out[:batch_size] + logits = model.run_joiner(current_encoder_out, decoder_out) - logits = joiner.run( - [joiner_output_nodes[0].name], - { - joiner_input_nodes[0].name: current_encoder_out, - joiner_input_nodes[1].name: projected_decoder_out.numpy(), - }, - )[0] - logits = torch.from_numpy(logits).squeeze(1).squeeze(1) # logits'shape (batch_size, vocab_size) assert logits.ndim == 2, logits.shape @@ -261,17 +337,7 @@ def greedy_search( decoder_input, dtype=torch.int64, ) - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - projected_decoder_out = torch.from_numpy(projected_decoder_out) + decoder_out = model.run_decoder(decoder_input) sorted_ans = [h[context_size:] for h in hyps] ans = [] @@ -287,39 +353,12 @@ def main(): parser = get_parser() args = parser.parse_args() logging.info(vars(args)) - - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - encoder = ort.InferenceSession( - args.encoder_model_filename, - sess_options=session_opts, + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, ) - decoder = ort.InferenceSession( - args.decoder_model_filename, - sess_options=session_opts, - ) - - 
joiner = ort.InferenceSession( - args.joiner_model_filename, - sess_options=session_opts, - ) - - joiner_encoder_proj = ort.InferenceSession( - args.joiner_encoder_proj_model_filename, - sess_options=session_opts, - ) - - joiner_decoder_proj = ort.InferenceSession( - args.joiner_decoder_proj_model_filename, - sess_options=session_opts, - ) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = "cpu" @@ -347,30 +386,27 @@ def main(): ) feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - - encoder_input_nodes = encoder.get_inputs() - encoder_out_nodes = encoder.get_outputs() - encoder_out, encoder_out_lens = encoder.run( - [encoder_out_nodes[0].name, encoder_out_nodes[1].name], - { - encoder_input_nodes[0].name: features.numpy(), - encoder_input_nodes[1].name: feature_lengths.numpy(), - }, - ) + encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) hyps = greedy_search( - decoder=decoder, - joiner=joiner, - joiner_encoder_proj=joiner_encoder_proj, - joiner_decoder_proj=joiner_decoder_proj, + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - context_size=args.context_size, ) s = "\n" + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += symbol_table[i] + return text.replace("▁", " ").strip() + + context_size = model.context_size for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" + words = token_ids_to_words(hyp[context_size:]) + s += f"{filename}:\n{words}\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 0e5111f33..3a1ecb7ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -426,18 +426,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -447,9 +443,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index f5cbc21f7..5ec3d3b45 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -109,10 +109,7 @@ Usage: To evaluate symbol delay, you should: (1) Generate cuts with word-time alignments: -./local/add_alignment_librispeech.py \ - --alignments-dir data/alignment \ - --cuts-in-dir data/fbank \ - --cuts-out-dir data/fbank_ali +./add_alignments.sh (2) Set the argument "--manifest-dir data/fbank_ali" while decoding. For example: ./pruned_transducer_stateless4/decode.py \ @@ -528,7 +525,6 @@ def decode_one_batch( res = DecodingResults(hyps=tokens, timestamps=timestamps) hyps, timestamps = parse_hyp_and_timestamp( - decoding_method=params.decoding_method, res=res, sp=sp, subsampling_factor=params.subsampling_factor, @@ -659,18 +655,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True @@ -681,9 +673,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -691,8 +681,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 401b3ef3a..8f33f5b05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -50,6 +50,7 @@ from pathlib import Path import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -261,6 +262,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py new file mode 120000 index 000000000..9aa06f82f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index c4e3cef16..ca3a023ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -442,18 +442,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -463,9 +459,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index b3a7d71bc..8bbceec61 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -32,7 +32,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask class Conformer(EncoderInterface): @@ -1012,15 +1012,28 @@ class RelPositionMultiheadAttention(nn.Module): n == left_context + 2 * time1 - 1 ), f"{n} == {left_context} + 2 * {time1} - 1" # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) + + if is_jit_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) def multi_head_attention_forward( self, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 90b0fcf4b..2be895feb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -735,18 +735,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -756,9 +752,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py new file mode 100755 index 000000000..743fe8a92 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-39-avg-7.pt" + +cd exp +ln -s pretrained-epoch-39-avg-7.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless5/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir $repo/exp \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
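+ For example, --iter 468000 --avg 16 selects checkpoint-468000.pt + together with the checkpoints saved just before it for averaging + (illustrative values; any saved iteration works).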
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Conformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Conformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Conformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). 
+ Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" +
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py new file mode 120000 index 000000000..66d63b807 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 064811f1c..5b15dcee7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -442,18 +442,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = 
params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -463,9 +459,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index fd9de052a..95534efef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -416,18 +416,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -437,9 +433,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index b9bce465f..32b3134b9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -722,18 +722,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -743,9 +739,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 5f90e6375..384b78524 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -87,7 +87,11 @@ class Decoder(nn.Module): y = y.to(torch.int64) # this stuff about clamp() is a temporary fix for a mismatch # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + if torch.jit.is_tracing(): + # This is for exporting to PNNX via ONNX + embedding_out = self.embedding(y) + else: + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py new file mode 100755 index 000000000..f76915a74 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -0,0 +1,560 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" + +cd exp +ln -s pretrained-epoch-30-avg-9.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. 
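+ +3. (Optional) Inspect the exported models, e.g. print the metadata that this +script attaches to the decoder (a sketch; it only assumes that onnxruntime is +available): + +python3 -c " +import onnxruntime as ort + +sess = ort.InferenceSession('$repo/exp/decoder-epoch-99-avg-1.onnx') +# context_size and vocab_size should appear among the entries +print(sess.get_modelmeta().custom_metadata_map) +"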
+""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). 
Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
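+ + Example (a sketch; `decoder` is the OnnxDecoder wrapper that main() + below builds from the trained model): + + export_decoder_model_onnx(decoder, "decoder.onnx", opset_version=13)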
+ """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg 
{params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index db8b5eb2b..3e3160e7e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -41,31 +41,7 @@ Check https://github.com/k2-fsa/sherpa for how to use the exported models outside of icefall. 
-(2) Export to ONNX format - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Please see ./onnx_pretrained.py for usage of the generated files - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. - -(3) Export `model.state_dict()` +(2) Export `model.state_dict()` ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ @@ -196,23 +172,6 @@ def get_parser(): """, ) - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit is ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - parser.add_argument( "--context-size", type=int, @@ -225,204 +184,6 @@ def get_parser(): return parser -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 101, 80, dtype=torch.float32) - x_lens = torch.tensor([101], dtype=torch.int64) - - # encoder_model = torch.jit.script(encoder_model) - # It throws the following error for the above statement - # - # RuntimeError: Exporting the operator __is_ to ONNX opset version - # 11 is not supported. Please feel free to request support or - # submit a pull request on PyTorch GitHub. - # - # I cannot find which statement causes the above error. - # torch.onnx.export() will use torch.jit.trace() internally, which - # works well for the current reworked model - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. 
- """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -531,31 +292,7 @@ def main(): model.to("cpu") model.eval() - if params.onnx is True: - convert_scaled_to_non_scaled(model, inplace=True) - opset_version = 13 - logging.info("Exporting to onnx format") - encoder_filename = params.exp_dir / "encoder.onnx" - 
export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - elif params.jit is True: + if params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py new file mode 100755 index 000000000..37edc0390 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`. + +(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model True \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`. + +(3) use the original model with checkpoint exp_dir/epoch-xxx.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --epoch 28 \ + --avg 15 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt +./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ + --iter 22000 \ + --avg 5 \ + --use-averaged-model False \ + --exp-dir ./pruned_transducer_stateless7/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. 
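+
+A minimal sketch (an illustration, not part of the recipe) of consuming a
+generated file, assuming `model` was constructed the same way as during
+training:
+
+    import torch
+
+    # The file stores a dict with a single "model" key (see main() below).
+    state = torch.load("iter-22000-avg-5.pt", map_location="cpu")
+    model.load_state_dict(state["model"])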
+""" + + +import argparse +from pathlib import Path +from typing import Dict, List + +import sentencepiece as spm +import torch + +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model." + "If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    print("Script started")
+
+    device = torch.device("cpu")
+    print(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    print("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            print(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+            filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
+            torch.save({"model": model.state_dict()}, filename)
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+            filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+            torch.save({"model": model.state_dict()}, filename)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            print(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+            filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+            torch.save({"model": model.state_dict()}, filename)
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            print(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+            filename = (
+                params.exp_dir
+                / f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt"
+            )
+            torch.save({"model": model.state_dict()}, filename)
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            print(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = ( + params.exp_dir + / f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt" + ) + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 3ddac2cf2..62a4d22d6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -53,7 +53,6 @@ class Joiner(nn.Module): """ assert encoder_out.ndim == decoder_out.ndim assert encoder_out.ndim in (2, 4) - assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py deleted file mode 100755 index 63acc0922..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py +++ /dev/null @@ -1,286 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. 
-""" - -import argparse -import logging - -import onnxruntime as ort -import torch - -from icefall import is_module_available - -if not is_module_available("onnxruntime"): - raise ValueError("Please 'pip install onnxruntime' first.") - - -ort.set_default_logger_severity(3) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the onnx encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the onnx decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the onnx joiner model", - ) - - parser.add_argument( - "--onnx-joiner-encoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner encoder projection model", - ) - - parser.add_argument( - "--onnx-joiner-decoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner decoder projection model", - ) - - return parser - - -def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, -): - inputs = encoder_session.get_inputs() - outputs = encoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] - - assert inputs[0].shape == ["N", "T", 80] - assert inputs[1].shape == ["N"] - - for N in [1, 5]: - for T in [12, 50]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T - - encoder_inputs = { - input_names[0]: x.numpy(), - input_names[1]: x_lens.numpy(), - } - - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out, encoder_out_lens = encoder_session.run( - output_names, - encoder_inputs, - ) - - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max(), - encoder_out.shape, - torch_encoder_out.shape, - ) - - -def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, -): - inputs = decoder_session.get_inputs() - outputs = decoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] - - assert inputs[0].shape == ["N", 2] - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - output_names, - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - - torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() - ) - - -def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, - joiner_encoder_proj_session: ort.InferenceSession, - joiner_decoder_proj_session: ort.InferenceSession, -): - joiner_inputs = joiner_session.get_inputs() - joiner_outputs = joiner_session.get_outputs() - joiner_input_names = [n.name for n in joiner_inputs] - joiner_output_names = [n.name for n in joiner_outputs] - - assert joiner_inputs[0].shape == ["N", 1, 1, 512] - assert joiner_inputs[1].shape == ["N", 1, 1, 
512] - - joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() - encoder_proj_input_name = joiner_encoder_proj_inputs[0].name - - assert joiner_encoder_proj_inputs[0].shape == ["N", 384] - - joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() - encoder_proj_output_name = joiner_encoder_proj_outputs[0].name - - joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() - decoder_proj_input_name = joiner_decoder_proj_inputs[0].name - - assert joiner_decoder_proj_inputs[0].shape == ["N", 512] - - joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() - decoder_proj_output_name = joiner_decoder_proj_outputs[0].name - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 384) - decoder_out = torch.rand(N, 512) - - projected_encoder_out = torch.rand(N, 1, 1, 512) - projected_decoder_out = torch.rand(N, 1, 1, 512) - - joiner_inputs = { - joiner_input_names[0]: projected_encoder_out.numpy(), - joiner_input_names[1]: projected_decoder_out.numpy(), - } - joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - projected_encoder_out, - projected_decoder_out, - project_input=False, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() - ) - - # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} - joiner_encoder_proj_out = joiner_encoder_proj_session.run( - [encoder_proj_output_name], joiner_encoder_proj_inputs - )[0] - joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) - - torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) - assert torch.allclose( - joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) - - # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} - joiner_decoder_proj_out = joiner_decoder_proj_session.run( - [decoder_proj_output_name], joiner_decoder_proj_inputs - )[0] - joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) - - torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) - assert torch.allclose( - joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - model = torch.jit.load(args.jit_filename) - - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 - - logging.info("Test encoder") - encoder_session = ort.InferenceSession( - args.onnx_encoder_filename, - sess_options=options, - ) - test_encoder(model, encoder_session) - - logging.info("Test decoder") - decoder_session = ort.InferenceSession( - args.onnx_decoder_filename, - sess_options=options, - ) - test_decoder(model, decoder_session) - - logging.info("Test joiner") - joiner_session = ort.InferenceSession( - args.onnx_joiner_filename, - sess_options=options, - ) - joiner_encoder_proj_session = ort.InferenceSession( - args.onnx_joiner_encoder_proj_filename, - sess_options=options, - ) - joiner_decoder_proj_session = ort.InferenceSession( - args.onnx_joiner_decoder_proj_filename, - sess_options=options, - ) - test_joiner( - model, - joiner_session, - joiner_encoder_proj_session, - joiner_decoder_proj_session, - ) - 
logging.info("Finished checking ONNX models") - - -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py new file mode 120000 index 000000000..20e334271 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_check.py @@ -0,0 +1 @@ +../pruned_transducer_stateless5/onnx_check.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py deleted file mode 100755 index 3a06ee293..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 \ - --onnx 1 - -Usage of this script: - -./pruned_transducer_stateless7/onnx_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless7/exp/encoder.onnx \ - --decoder-model-filename ./pruned_transducer_stateless7/exp/decoder.onnx \ - --joiner-model-filename ./pruned_transducer_stateless7/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7/exp/joiner_decoder_proj.onnx \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import kaldifeat -import numpy as np -import onnxruntime as ort -import sentencepiece as spm -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--joiner-encoder-proj-model-filename", - type=str, - required=True, - help="Path to the joiner encoder_proj onnx model. 
", - ) - - parser.add_argument( - "--joiner-decoder-proj-model-filename", - type=str, - required=True, - help="Path to the joiner decoder_proj onnx model. ", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: ort.InferenceSession, - joiner: ort.InferenceSession, - joiner_encoder_proj: ort.InferenceSession, - joiner_decoder_proj: ort.InferenceSession, - encoder_out: np.ndarray, - encoder_out_lens: np.ndarray, - context_size: int, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - decoder: - The decoder model. - joiner: - The joiner model. - joiner_encoder_proj: - The joiner encoder projection model. - joiner_decoder_proj: - The joiner decoder projection model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. - Returns: - Return the decoded results for each utterance. 
- """ - encoder_out = torch.from_numpy(encoder_out) - encoder_out_lens = torch.from_numpy(encoder_out_lens) - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - projected_encoder_out = joiner_encoder_proj.run( - [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, - )[0] - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input_nodes = decoder.get_inputs() - decoder_output_nodes = decoder.get_outputs() - - joiner_input_nodes = joiner.get_inputs() - joiner_output_nodes = joiner.get_outputs() - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - - projected_decoder_out = torch.from_numpy(projected_decoder_out) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = projected_encoder_out[start:end] - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - projected_decoder_out = projected_decoder_out[:batch_size] - - logits = joiner.run( - [joiner_output_nodes[0].name], - { - joiner_input_nodes[0].name: np.expand_dims( - np.expand_dims(current_encoder_out, axis=1), axis=1 - ), - joiner_input_nodes[1] - .name: projected_decoder_out.unsqueeze(1) - .unsqueeze(1) - .numpy(), - }, - )[0] - logits = torch.from_numpy(logits).squeeze(1).squeeze(1) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = decoder.run( - [decoder_output_nodes[0].name], - { - decoder_input_nodes[0].name: decoder_input.numpy(), - }, - )[0].squeeze(1) - projected_decoder_out = joiner_decoder_proj.run( - [joiner_decoder_proj.get_outputs()[0].name], - {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, - )[0] - projected_decoder_out = torch.from_numpy(projected_decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - encoder = ort.InferenceSession( - args.encoder_model_filename, - sess_options=session_opts, - ) - - decoder = ort.InferenceSession( - args.decoder_model_filename, - sess_options=session_opts, - ) - - joiner = 
ort.InferenceSession( - args.joiner_model_filename, - sess_options=session_opts, - ) - - joiner_encoder_proj = ort.InferenceSession( - args.joiner_encoder_proj_model_filename, - sess_options=session_opts, - ) - - joiner_decoder_proj = ort.InferenceSession( - args.joiner_decoder_proj_model_filename, - sess_options=session_opts, - ) - - sp = spm.SentencePieceProcessor() - sp.load(args.bpe_model) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - - encoder_input_nodes = encoder.get_inputs() - encoder_out_nodes = encoder.get_outputs() - encoder_out, encoder_out_lens = encoder.run( - [encoder_out_nodes[0].name, encoder_out_nodes[1].name], - { - encoder_input_nodes[0].name: features.numpy(), - encoder_input_nodes[1].name: feature_lengths.numpy(), - }, - ) - - hyps = greedy_search( - decoder=decoder, - joiner=joiner, - joiner_encoder_proj=joiner_encoder_proj, - joiner_decoder_proj=joiner_decoder_proj, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - context_size=args.context_size, - ) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = sp.decode(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py new file mode 120000 index 000000000..7607623c8 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/onnx_pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 56165d1f9..86067b04f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -22,11 +22,101 @@ BasicNorm is replaced by a module with `exp` removed. 
""" import copy -from typing import List +from typing import List, Tuple import torch import torch.nn as nn from scaling import ActivationBalancer, BasicNorm, Whiten +from zipformer import PoolingModule + + +class PoolingModuleNoProj(nn.Module): + def forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x = x.cumsum(dim=0) # (T, N, C) + x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) + # Cumulated numbers of frames from start + cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) + cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) + pooling_mask = (1.0 / cum_mask).unsqueeze(2) + # now pooling_mask: (T, N, 1) + x = x * pooling_mask # (T, N, C) + + cached_len = cached_len + x.size(0) + cached_avg = x[-1] + + return x, cached_len, cached_avg + + +class PoolingModuleWithProj(nn.Module): + def __init__(self, proj: torch.nn.Module): + super().__init__() + self.proj = proj + self.pooling = PoolingModuleNoProj() + + def forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg) + return self.proj(x), cached_len, cached_avg + + def streaming_forward( + self, + x: torch.Tensor, + cached_len: torch.Tensor, + cached_avg: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (T, N, C) + cached_len: + A tensor of shape (N,) + cached_avg: + A tensor of shape (N, C) + Returns: + Return a tuple containing: + - new_x + - new_cached_len + - new_cached_avg + """ + x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg) + return self.proj(x), cached_len, cached_avg class NonScaledNorm(nn.Module): @@ -53,7 +143,7 @@ class NonScaledNorm(nn.Module): def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(BasicNorm) + assert isinstance(basic_norm, BasicNorm), type(basic_norm) norm = NonScaledNorm( num_channels=basic_norm.num_channels, eps_exp=basic_norm.eps.data.exp().item(), @@ -62,6 +152,11 @@ def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: return norm +def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj: + assert isinstance(pooling, PoolingModule), type(pooling) + return PoolingModuleWithProj(proj=pooling.proj) + + # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa # get_submodule was added to nn.Module at v1.9.0 def get_submodule(model, target): @@ -83,6 +178,7 @@ def get_submodule(model, target): def convert_scaled_to_non_scaled( model: nn.Module, inplace: bool = False, + is_pnnx: bool = False, ): """ Args: @@ -91,6 +187,8 @@ def convert_scaled_to_non_scaled( inplace: If True, the input model is modified inplace. If False, the input model is copied and we modify the copied version. + is_pnnx: + True if we are going to export the model for PNNX. 
    Return:
      Return a model without scaled layers.
    """
@@ -103,6 +201,8 @@ def convert_scaled_to_non_scaled(
             d[name] = convert_basic_norm(m)
         elif isinstance(m, (ActivationBalancer, Whiten)):
             d[name] = nn.Identity()
+        elif isinstance(m, PoolingModule) and is_pnnx:
+            d[name] = convert_pooling_module(m)
 
     for k, v in d.items():
         if "." in k:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 31a3a0505..792a243e5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -82,7 +82,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    filter_uneven_sized_batch,
+    setup_logger,
+    str2bool,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -420,6 +426,8 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
+            "allowed_excess_duration_ratio": 0.1,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -642,6 +650,17 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
+    # For an uneven-sized batch, the total duration after padding could cause
+    # OOM. Hence, for each batch, which is sorted in descending order of
+    # length, we simply drop the last few shortest samples, so that the
+    # retained total frames (after padding) do not exceed `allowed_max_frames`:
+    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+    # We set allowed_excess_duration_ratio=0.1.
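+    # For example, with a (hypothetical) --max-duration of 600 seconds and a
+    # 10 ms frame shift: max_frames = 600 * 1000 // 10 = 60000 frames, and
+    # allowed_max_frames = int(60000 * 1.1) = 66000 frames.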
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) @@ -1024,10 +1043,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index b1717ec64..3959c0bb2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( from torch import Tensor, nn from icefall.dist import get_rank -from icefall.utils import make_pad_mask +from icefall.utils import is_jit_tracing, make_pad_mask class Zipformer(EncoderInterface): @@ -197,13 +197,13 @@ class Zipformer(EncoderInterface): """ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than + On e.g. 15% of frames, these masks will zero out all encoder dims larger than some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. + a smaller encoder dim. We generate the random masks at this level because we want the 2 masks to 'agree' all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. + mask values repeated self.zipformer_downsampling_factors times. Args: x: the embeddings (needed for the shape and dtype and device), of shape @@ -792,7 +792,8 @@ class AttentionDownsample(torch.nn.Module): src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -904,6 +905,13 @@ class RelPositionalEncoding(torch.nn.Module): def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() + if is_jit_tracing(): + # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., + # It assumes that the maximum input won't have more than + # 10k frames. + # + # TODO(fangjun): Use torch.jit.script() for this module + max_len = 10000 self.d_model = d_model self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None @@ -1009,10 +1017,10 @@ class RelPositionMultiheadAttention(nn.Module): # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. 
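+        # E.g., with hypothetical values attention_dim=512, num_heads=8 and
+        # pos_dim=4: in_proj_dim = 2*512 + 512//2 + 4*8 = 1312.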
in_proj_dim = ( - 2 * attention_dim - + attention_dim // 2 # query, key - + pos_dim * num_heads # value - ) # positional encoding query + 2 * attention_dim # query, key + + attention_dim // 2 # value + + pos_dim * num_heads # positional encoding query + ) self.in_proj = ScaledLinear( embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 @@ -1509,7 +1517,7 @@ class FeedforwardModule(nn.Module): class ConvolutionModule(nn.Module): """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py Args: channels (int): The number of channels of conv layers. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py index 4b373e4c7..629bec058 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -541,18 +541,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) test_set_wers[key] = wer @@ -560,9 +556,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py index 32a9b6bb2..7641fa5af 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py @@ -593,18 +593,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -614,9 +610,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index 5a05e1836..718381baa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1072,10 +1072,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py old mode 100755 new mode 100644 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py index f137485b2..fa7144f0f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -532,18 +532,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) test_set_wers[key] = wer @@ -551,9 +547,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py rename to egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py index 9c2166aaf..01ba7b711 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -21,7 +21,7 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -29,7 +29,7 @@ Usage: --decoding-method greedy_search (2) beam search (not recommended) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -38,7 +38,7 @@ Usage: --beam-size 4 (3) modified beam search -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -47,7 +47,7 @@ Usage: --beam-size 4 (4) fast beam search (one best) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -58,7 +58,7 @@ Usage: --max-states 64 (5) fast beam search (nbest) -./pruned_transducer_stateless7_ctc/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ @@ -71,7 +71,7 @@ Usage: --nbest-scale 0.5 (6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ @@ -84,7 +84,7 @@ Usage: --nbest-scale 0.5 (7) fast beam search (with LG) -./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \ +./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py index ce45a4beb..e497787d3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -593,18 +593,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -614,9 +610,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py index 96d316604..05df8cfff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -72,14 +72,14 @@ Check ./pretrained.py for its usage. Note: If you don't want to train a model from scratch, we have provided one for you. You can get it at -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 with the following commands: sudo apt-get install git-lfs git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp + git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp """ import argparse diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py new file mode 100755 index 000000000..50efa6e60 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -0,0 +1,665 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to ONNX format + +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - frame_reducer.onnx + +Please see ./onnx_pretrained.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + # You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc_bs/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--onnx", + type=str2bool, + default=True, + help="""If True, --jit is ignored and it exports the model + to onnx format. 
It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + - lconv.onnx + - frame_reducer.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(15, 2000, 80, dtype=torch.float32) + x_lens = torch.tensor([2000] * 15, dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(15, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_lconv_onnx( + lconv: nn.Module, + lconv_filename: str, + opset_version: int = 11, +) -> None: + """Export the lconv to ONNX format. + + The exported lconv has two inputs: + + - lconv_input: a tensor of shape (N, T, C) + - src_key_padding_mask: a tensor of shape (N, T) + + and has one output: + + - lconv_out: a tensor of shape (N, T, C) + + Args: + lconv: + The lconv to be exported. + lconv_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
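+
+    A minimal sketch of running the exported file (assuming onnxruntime;
+    the tensor names match the input_names/output_names used below, and
+    ./onnx_pretrained.py builds src_key_padding_mask the same way):
+
+        import onnxruntime as ort
+
+        lconv = ort.InferenceSession("lconv.onnx")
+        # lconv_input: (N, T, C) float32
+        # src_key_padding_mask: (N, T) bool, True on padded positions
+        out = lconv.run(
+            None,
+            {
+                "lconv_input": lconv_input,
+                "src_key_padding_mask": src_key_padding_mask,
+            },
+        )[0]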
+ """ + lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32) + src_key_padding_mask = torch.zeros(15, 498, dtype=torch.bool) + + torch.onnx.export( + lconv, + (lconv_input, src_key_padding_mask), + lconv_filename, + verbose=False, + opset_version=opset_version, + input_names=["lconv_input", "src_key_padding_mask"], + output_names=["lconv_out"], + dynamic_axes={ + "lconv_input": {0: "N", 1: "T"}, + "src_key_padding_mask": {0: "N", 1: "T"}, + "lconv_out": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {lconv_filename}") + + +def export_frame_reducer_onnx( + frame_reducer: nn.Module, + frame_reducer_filename: str, + opset_version: int = 11, +) -> None: + """Export the frame_reducer to ONNX format. + + The exported frame_reducer has four inputs: + + - x: a tensor of shape (N, T, C) + - x_lens: a tensor of shape (N, T) + - ctc_output: a tensor of shape (N, T, vocab_size) + - blank_id: an int, always 0 + + and has two outputs: + + - x_fr: a tensor of shape (N, T, C) + - x_lens_fr: a tensor of shape (N, T) + + Args: + frame_reducer: + The frame_reducer to be exported. + frame_reducer_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(15, 498, 384, dtype=torch.float32) + x_lens = torch.tensor([498] * 15, dtype=torch.int64) + ctc_output = torch.randn(15, 498, 500, dtype=torch.float32) + + torch.onnx.export( + frame_reducer, + (x, x_lens, ctc_output), + frame_reducer_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "ctc_output"], + output_names=["out", "out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "ctc_output": {0: "N", 1: "T"}, + "out": {0: "N", 1: "T"}, + "out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {frame_reducer_filename}") + + +def export_ctc_output_onnx( + ctc_output: nn.Module, + ctc_output_filename: str, + opset_version: int = 11, +) -> None: + """Export the frame_reducer to ONNX format. + + The exported frame_reducer has one inputs: + + - encoder_out: a tensor of shape (N, T, C) + + and has one output: + + - ctc_output: a tensor of shape (N, T, vocab_size) + + Args: + ctc_output: + The ctc_output to be exported. + ctc_output_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + encoder_out = torch.zeros(15, 498, 384, dtype=torch.float32) + + torch.onnx.export( + ctc_output, + (encoder_out), + ctc_output_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["ctc_output"], + dynamic_axes={ + "encoder_out": {0: "N", 1: "T"}, + "ctc_output": {0: "N", 1: "T"}, + }, + ) + logging.info(f"Saved to {ctc_output_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + 
model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + lconv_filename = params.exp_dir / "lconv.onnx" + export_lconv_onnx( + model.lconv, + lconv_filename, + opset_version=opset_version, + ) + + frame_reducer_filename = params.exp_dir / "frame_reducer.onnx" + export_frame_reducer_onnx( + model.frame_reducer, + frame_reducer_filename, + opset_version=opset_version, + ) + + ctc_output_filename = params.exp_dir / "ctc_output.onnx" + export_ctc_output_onnx( + model.ctc_output, + ctc_output_filename, + opset_version=opset_version, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py old mode 100755 new mode 100644 index 9fe88929d..0841f7cf1 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # -# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, -# Zengwei Yao) +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, +# Zengwei Yao, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -18,11 +19,12 @@ # limitations under the License. import math -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F + from icefall.utils import make_pad_mask @@ -43,6 +45,7 @@ class FrameReducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ctc_output: torch.Tensor, + y_lens: Optional[torch.Tensor] = None, blank_id: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -54,26 +57,110 @@ class FrameReducer(nn.Module): `x` before padding. ctc_output: The CTC output with shape [N, T, vocab_size]. - blank_id: - The ID of the blank symbol. - Returns: - x_fr: - The frame reduced encoder output with shape [N, T', C]. - x_lens_fr: + y_lens: A tensor of shape (batch_size,) containing the number of frames in - `x_fr` before padding. + `y` before padding. + blank_id: + The blank id of ctc_output. + Returns: + out: + The frame reduced encoder output with shape [N, T', C]. + out_lens: + A tensor of shape (batch_size,) containing the number of frames in + `out` before padding. 
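+
+        Note: when `y_lens` is given, the number of frames removed from
+        utterance i is capped at `T - y_lens[i]`, which guarantees
+        `out_lens[i] >= y_lens[i]`, i.e. the reduced output is never
+        shorter than its label sequence.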
""" + N, T, C = x.size() padding_mask = make_pad_mask(x_lens) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) - frames_list: List[torch.Tensor] = [] - lens_list: List[int] = [] - for i in range(x.shape[0]): - frames = x[i][non_blank_mask[i]] - frames_list.append(frames) - lens_list.append(frames.shape[0]) - x_fr = pad_sequence(frames_list, batch_first=True) - x_lens_fr = torch.tensor(lens_list).to(device=x.device) + if y_lens is not None: + # Limit the maximum number of reduced frames + limit_lens = T - y_lens + max_limit_len = limit_lens.max().int() + fake_limit_indexes = torch.topk( + ctc_output[:, :, blank_id], max_limit_len + ).indices + T = ( + torch.arange(max_limit_len) + .expand_as( + fake_limit_indexes, + ) + .to(device=x.device) + ) + T = torch.remainder(T, limit_lens.unsqueeze(1)) + limit_indexes = torch.gather(fake_limit_indexes, 1, T) + limit_mask = torch.full_like( + non_blank_mask, + False, + device=x.device, + ).scatter_(1, limit_indexes, True) - return x_fr, x_lens_fr + non_blank_mask = non_blank_mask | ~limit_mask + + out_lens = non_blank_mask.sum(dim=1) + max_len = out_lens.max() + pad_lens_list = ( + torch.full_like( + out_lens, + max_len.item(), + device=x.device, + ) + - out_lens + ) + max_pad_len = pad_lens_list.max() + + out = F.pad(x, (0, 0, 0, max_pad_len)) + + valid_pad_mask = ~make_pad_mask(pad_lens_list) + total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1) + + out = out[total_valid_mask].reshape(N, -1, C) + + return out, out_lens + + +if __name__ == "__main__": + import time + + test_times = 10000 + device = "cuda:0" + frame_reducer = FrameReducer() + + # non zero case + x = torch.ones(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.log( + torch.randn(15, 498, 500, dtype=torch.float32, device=device), + ) + + avg_time = 0 + for i in range(test_times): + torch.cuda.synchronize(device=x.device) + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) + + # all zero case + x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device) + + avg_time = 0 + for i in range(test_times): + torch.cuda.synchronize(device=x.device) + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py old mode 100755 new mode 100644 index bfd49d533..a902358ae --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py @@ -62,7 +62,7 @@ class LConv(nn.Module): kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, - groups=channels, + groups=2 * channels, bias=bias, ) diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py old mode 100755 new mode 100644 index 86acc5a10..0582b289f --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -131,6 +131,10 @@ class Transducer(nn.Module): # compute ctc log-probs ctc_output = self.ctc_output(encoder_out) + # y_lens + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + # blank skip blank_id = self.decoder.blank_id @@ -146,16 +150,14 @@ class Transducer(nn.Module): encoder_out, x_lens, ctc_output, + y_lens, blank_id, ) else: encoder_out_fr = encoder_out x_lens_fr = x_lens - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - + # sos_y sos_y = add_sos(y, sos_id=blank_id) # sos_y_padded: [B, S + 1], start with SOS. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py new file mode 100755 index 000000000..8ff02fbcb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. 
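+
+The decoding pipeline implemented below runs, in order: fbank feature
+extraction, the encoder, the CTC output head (per-frame log-probs), the
+lconv module, the frame reducer (which drops frames that are almost surely
+blank) and, finally, greedy search with the decoder and joiner models.
+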
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ + --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 13 + +Usage of this script: + +./pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_decoder_proj.onnx \ + --lconv-filename ./pruned_transducer_stateless7_ctc_bs/exp/lconv.onnx \ + --frame-reducer-filename ./pruned_transducer_stateless7_ctc_bs/exp/frame_reducer.onnx \ + --ctc-output-filename ./pruned_transducer_stateless7_ctc_bs/exp/ctc_output.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence + +from icefall.utils import make_pad_mask + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--lconv-filename", + type=str, + required=True, + help="Path to the lconv onnx model. ", + ) + + parser.add_argument( + "--frame-reducer-filename", + type=str, + required=True, + help="Path to the frame reducer onnx model. ", + ) + + parser.add_argument( + "--ctc-output-filename", + type=str, + required=True, + help="Path to the ctc_output onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. 
+ expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + )[0] + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = projected_encoder_out[start:end] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + projected_decoder_out = projected_decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: np.expand_dims( + np.expand_dims(current_encoder_out, axis=1), axis=1 + ), + joiner_input_nodes[1] + .name: projected_decoder_out.unsqueeze(1) + .unsqueeze(1) + .numpy(), + }, + )[0] + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = 
[h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + + lconv = ort.InferenceSession( + args.lconv_filename, + sess_options=session_opts, + ) + + frame_reducer = ort.InferenceSession( + args.frame_reducer_filename, + sess_options=session_opts, + ) + + ctc_output = ort.InferenceSession( + args.ctc_output_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + ctc_output_input_nodes = ctc_output.get_inputs() + ctc_output_out_nodes = ctc_output.get_outputs() + ctc_out = ctc_output.run( + [ctc_output_out_nodes[0].name], + { + ctc_output_input_nodes[0].name: encoder_out, + }, + )[0] + + lconv_input_nodes = lconv.get_inputs() + lconv_out_nodes = lconv.get_outputs() + encoder_out = lconv.run( + [lconv_out_nodes[0].name], + { + lconv_input_nodes[0].name: encoder_out, + lconv_input_nodes[1] + .name: make_pad_mask(torch.from_numpy(encoder_out_lens)) + .numpy(), + }, + )[0] + + frame_reducer_input_nodes = frame_reducer.get_inputs() + frame_reducer_out_nodes = frame_reducer.get_outputs() + encoder_out_fr, 
encoder_out_lens_fr = frame_reducer.run( + [frame_reducer_out_nodes[0].name, frame_reducer_out_nodes[1].name], + { + frame_reducer_input_nodes[0].name: encoder_out, + frame_reducer_input_nodes[1].name: encoder_out_lens, + frame_reducer_input_nodes[2].name: ctc_out, + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, + encoder_out=encoder_out_fr, + encoder_out_lens=encoder_out_lens_fr, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 522ecc974..ea280e642 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, # Mingshuang Luo, @@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --use-fp16 1 \ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \ --full-libri 1 \ - --max-duration 550 + --max-duration 750 """ @@ -55,9 +54,9 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from decoder import Decoder +from frame_reducer import FrameReducer from joiner import Joiner from lconv import LConv -from frame_reducer import FrameReducer from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed @@ -1063,10 +1062,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md index 6e461e196..d3691e647 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md @@ -1,3 +1,10 @@ This recipe implements Streaming Zipformer-Transducer model. See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials. + +[./emformer.py](./emformer.py) and [./train.py](./train.py) +are basically the same as +[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py). +The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py) +is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn). 
+ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index aebe2b94b..e7616fbc5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -568,18 +568,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -589,9 +585,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py new file mode 100755 index 000000000..1f870ca5a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +to demonstrate the usage of this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_char_bpe/L.pt" +git lfs pull --include "data/lang_char_bpe/L_disambig.pt" +git lfs pull --include "data/lang_char_bpe/Linv.pt" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export to ncnn + +./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ + --lang-dir $repo/data/lang_char_bpe \ + --exp-dir $repo/exp \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +cd $repo/exp + +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. 
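+
+    Note: before tracing, `forward` is rebound to `streaming_forward`, so
+    the traced graph is the streaming (chunk-by-chunk) computation; the
+    example input below covers one decode chunk plus 7 padding frames.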
+ """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = 
find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py new file mode 100755 index 000000000..0f84eca83 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 + +""" +Please see +https://k2-fsa.github.io/icefall/model-export/export-ncnn.html +for more details about how to use this file. + +We use +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +to demonstrate the usage of this file. + +1. 
Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. Export to ncnn
+
+./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
+  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --exp-dir $repo/exp \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  \
+  --decode-chunk-len 32 \
+  --num-encoder-layers "2,4,3,2,4" \
+  --feedforward-dims "1024,1024,2048,2048,1024" \
+  --nhead "8,8,8,8,8" \
+  --encoder-dims "384,384,384,384,384" \
+  --attention-dims "192,192,192,192,192" \
+  --encoder-unmasked-dims "256,256,256,256,256" \
+  --zipformer-downsampling-factors "1,2,4,8,2" \
+  --cnn-module-kernels "31,31,31,31,31" \
+  --decoder-dim 512 \
+  --joiner-dim 512
+
+cd $repo/exp
+
+pnnx encoder_jit_trace-pnnx.pt
+pnnx decoder_jit_trace-pnnx.pt
+pnnx joiner_jit_trace-pnnx.pt
+
+You can find converted models at
+https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
+
+See ./streaming-ncnn-decode.py
+and
+https://github.com/k2-fsa/sherpa-ncnn
+for usage.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train2 import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import setup_logger, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7_streaming/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. 
", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + + decode_chunk_len = encoder_model.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length # 32 + 7 = 39 + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.get_init_state() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. 
+ + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + setup_logger(f"{params.exp_dir}/log-export/log-export-ncnn") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_pnnx=True) + + encoder_num_param = sum([p.numel() for p in model.encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in 
model.decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in model.joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 100755 index 000000000..d7092403e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1,639 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to use the exported models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import onnx +import sentencepiece as spm +import torch +import torch.nn as nn +from decoder import Decoder +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
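+        For example, --epoch 99 --avg 1 --use-averaged-model 0 simply loads
+        exp_dir/epoch-99.pt, as in the usage example at the top of this file.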
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + """Please see the help information of Zipformer.streaming_forward""" + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, device=x.device) + + output, _, new_states = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + output = self.encoder_proj(output) + # Now output is of shape (N, T, joiner_dim) + + return output, new_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). 
+ Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """ + Onnx model inputs: + - 0: src + - many state tensors (the exact number depending on the actual model) + + Onnx model outputs: + - 0: output, its shape is (N, T, joiner_dim) + - many state tensors (the exact number depending on the actual model) + + Args: + encoder_model: + The model to be exported + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"T: {T}") + + x = torch.rand(1, T, 80, dtype=torch.float32) + + init_state = encoder_model.encoder.get_init_state() + + num_encoders = encoder_model.encoder.num_encoders + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, name, N): + for i, s in enumerate(tensors): + logging.info(f"{name}_{i}.shape: {s.shape}") + inputs[f"{name}_{i}"] = {N: "N"} + outputs[f"new_{name}_{i}"] = {N: "N"} + input_names.append(f"{name}_{i}") + output_names.append(f"new_{name}_{i}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) + attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) + ds = encoder_model.encoder.zipformer_downsampling_factors + left_context_len = encoder_model.encoder.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + 
"attention_dims": attention_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + } + logging.info(f"meta_data: {meta_data}") + + # (num_encoder_layers, 1) + cached_len = init_state[num_encoders * 0 : num_encoders * 1] + + # (num_encoder_layers, 1, encoder_dim) + cached_avg = init_state[num_encoders * 1 : num_encoders * 2] + + # (num_encoder_layers, left_context_len, 1, attention_dim) + cached_key = init_state[num_encoders * 2 : num_encoders * 3] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val = init_state[num_encoders * 3 : num_encoders * 4] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] + + build_inputs_outputs(cached_len, "cached_len", 1) + build_inputs_outputs(cached_avg, "cached_avg", 1) + build_inputs_outputs(cached_key, "cached_key", 2) + build_inputs_outputs(cached_val, "cached_val", 2) + build_inputs_outputs(cached_val2, "cached_val2", 2) + build_inputs_outputs(cached_conv1, "cached_conv1", 1) + build_inputs_outputs(cached_conv2, "cached_conv2", 1) + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + 
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + if params.use_averaged_model: + suffix += "-with-averaged-model" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 5c06cc052..1bc54fa26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -72,25 +72,81 @@ Check ./pretrained.py for its usage. Note: If you don't want to train a model from scratch, we have provided one for you. 
You can get it at -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 with the following commands: sudo apt-get install git-lfs git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp + +(3) Export to ONNX format with pretrained.pt + +cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp +ln -s pretrained.pt epoch-999.pt +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model False \ + --epoch 999 \ + --avg 1 \ + --fp16 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +(4) Export to ONNX format for triton server + +cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp +ln -s pretrained.pt epoch-999.pt +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --use-averaged-model False \ + --epoch 999 \ + --avg 1 \ + --fp16 \ + --onnx-triton 1 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + +Check +https://github.com/k2-fsa/sherpa/tree/master/triton +for how to use the exported models outside of icefall. + """ + import argparse import logging from pathlib import Path +import onnxruntime import sentencepiece as spm import torch import torch.nn as nn +from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states from icefall.checkpoint import ( average_checkpoints, @@ -172,6 +228,42 @@ def get_parser(): """, ) + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--onnx-triton", + type=str2bool, + default=False, + help="""If True, --onnx would export model into the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton. 
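+    Note: this flag takes effect only when --onnx=1 is also given; on its
+    own it exports nothing.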
+ """, + ) + + parser.add_argument( + "--fp16", + action="store_true", + help="whether to export fp16 onnx model, default false", + ) + parser.add_argument( "--context-size", type=int, @@ -184,6 +276,391 @@ def get_parser(): return parser +def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): + for a, b in zip(xlist, blist): + try: + torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + except AssertionError as error: + if tolerate_small_mismatch: + print("small mismatch detected", error) + else: + return False + return True + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + batch_size = 17 + seq_len = 101 + torch.manual_seed(0) + x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32) + x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. 
+ # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + initial_states = [encoder_model.get_init_state() for _ in range(batch_size)] + states = stack_states(initial_states) + + left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks + encoder_attention_dim = encoder_model.encoders[0].attention_dim + + len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15 + avg_cache = torch.cat( + states[encoder_model.num_encoders : 2 * encoder_model.num_encoders] + ).transpose( + 0, 1 + ) # [B,15,384] + cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose( + 0, 1 + ) # [B,2*15,384,cnn_kernel-1] + pad_tensors = [ + torch.nn.functional.pad( + tensor, + ( + 0, + encoder_attention_dim - tensor.shape[-1], + 0, + 0, + 0, + left_context_len - tensor.shape[1], + 0, + 0, + ), + ) + for tensor in states[ + 2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders + ] + ] + attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] + + encoder_model_wrapper = OnnxStreamingEncoder(encoder_model) + + torch.onnx.export( + encoder_model_wrapper, + (x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "x", + "x_lens", + "len_cache", + "avg_cache", + "attn_cache", + "cnn_cache", + ], + output_names=[ + "encoder_out", + "encoder_out_lens", + "new_len_cache", + "new_avg_cache", + "new_attn_cache", + "new_cnn_cache", + ], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + "len_cache": {0: "N"}, + "avg_cache": {0: "N"}, + "attn_cache": {0: "N"}, + "cnn_cache": {0: "N"}, + "new_len_cache": {0: "N"}, + "new_avg_cache": {0: "N"}, + "new_attn_cache": {0: "N"}, + "new_cnn_cache": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + # Test onnx encoder with torch native encoder + encoder_model.eval() + ( + encoder_out_torch, + encoder_out_lens_torch, + new_states_torch, + ) = encoder_model.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + ort_session = onnxruntime.InferenceSession( + str(encoder_filename), providers=["CPUExecutionProvider"] + ) + ort_inputs = { + "x": x.numpy(), + "x_lens": x_lens.numpy(), + "len_cache": len_cache.numpy(), + "avg_cache": avg_cache.numpy(), + "attn_cache": attn_cache.numpy(), + "cnn_cache": cnn_cache.numpy(), + } + ort_outs = ort_session.run(None, ort_inputs) + + assert test_acc( + [encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2] + ) + logging.info(f"{encoder_filename} acc test succeeded.") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
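+
+    Example:
+      A minimal sketch of driving the exported model with onnxruntime
+      (assuming the file was saved as decoder.onnx and that the tracer
+      folded the constant need_pad input away):
+
+        import numpy as np
+        import onnxruntime as ort
+
+        session = ort.InferenceSession("decoder.onnx")
+        y = np.zeros((4, 2), dtype=np.int64)  # (N, context_size)
+        (decoder_out,) = session.run(["decoder_out"], {"y": y})
+        # decoder_out has shape (N, 1, decoder_dim)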
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_decoder_model_onnx_triton( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + + decoder_model = TritonOnnxDecoder(decoder_model) + + torch.onnx.export( + decoder_model, + (y,), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + 
verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + +def export_joiner_model_onnx_triton( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported model has two inputs: + - encoder_out: a tensor of shape (N, encoder_out_dim) + - decoder_out: a tensor of shape (N, decoder_out_dim) + and has one output: + - joiner_out: a tensor of shape (N, vocab_size) + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + joiner_model = TritonOnnxJoiner(joiner_model) + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (encoder_out, decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out", "decoder_out"], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -292,7 +769,87 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.onnx: + convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 13 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + if not params.onnx_triton: + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + else: + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx_triton( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx_triton( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + if params.fp16: + try: + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + except ImportError: + print("Please install onnxmltools!") + import sys + + sys.exit(1) + + def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + 
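+            # Note: by default convert_float_to_float16() also converts the
+            # graph inputs/outputs to float16, so the *_fp16.onnx models
+            # created below expect float16 feature tensors; int64 inputs
+            # such as the decoder's y are left untouched.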
encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_fp16_filename) + + decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx" + export_onnx_fp16(decoder_filename, decoder_fp16_filename) + + joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx" + export_onnx_fp16(joiner_filename, joiner_fp16_filename) + + if not params.onnx_triton: + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + encoder_proj_fp16_filename = ( + params.exp_dir / "joiner_encoder_proj_fp16.onnx" + ) + export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + decoder_proj_fp16_filename = ( + params.exp_dir / "joiner_decoder_proj_fp16.onnx" + ) + export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename) + + elif params.jit: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 100755 index 000000000..6c78ba70b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script checks that exported ONNX models produce the same output +with the given torchscript model for the same input. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model via torch.jit.trace() + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp + + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + +3. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +4. 
Run this file + +./pruned_transducer_stateless7_streaming/onnx_check.py \ + --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ + --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ + --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx +""" + +import argparse +import logging + +from onnx_pretrained import OnnxModel +from zipformer import stack_states + +from icefall import is_module_available + +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-encoder-filename", + required=True, + type=str, + help="Path to the torchscript encoder model", + ) + + parser.add_argument( + "--jit-decoder-filename", + required=True, + type=str, + help="Path to the torchscript decoder model", + ) + + parser.add_argument( + "--jit-joiner-filename", + required=True, + type=str, + help="Path to the torchscript joiner model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the ONNX encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the ONNX decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the ONNX joiner model", + ) + + return parser + + +def test_encoder( + torch_encoder_model: torch.jit.ScriptModule, + torch_encoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + N = torch.randint(1, 100, size=(1,)).item() + T = onnx_model.segment + C = 80 + x_lens = torch.tensor([T] * N) + torch_states = [torch_encoder_model.get_init_state() for _ in range(N)] + torch_states = stack_states(torch_states) + + onnx_model.init_encoder_states(N) + + for i in range(5): + logging.info(f"test_encoder: iter {i}") + x = torch.rand(N, T, C) + torch_encoder_out, _, torch_states = torch_encoder_model( + x, x_lens, torch_states + ) + torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) + + onnx_encoder_out = onnx_model.run_encoder(x) + + assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( + (torch_encoder_out - onnx_encoder_out).abs().max() + ) + + +def test_decoder( + torch_decoder_model: torch.jit.ScriptModule, + torch_decoder_proj_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + context_size = onnx_model.context_size + vocab_size = onnx_model.vocab_size + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_decoder: iter {i}, N={N}") + x = torch.randint( + low=1, + high=vocab_size, + size=(N, context_size), + dtype=torch.int64, + ) + torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) + torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) + torch_decoder_out = torch_decoder_out.squeeze(1) + + onnx_decoder_out = onnx_model.run_decoder(x) + assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( + (torch_decoder_out - onnx_decoder_out).abs().max() + ) + + +def test_joiner( + torch_joiner_model: torch.jit.ScriptModule, + onnx_model: OnnxModel, +): + encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] + decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] + for i in range(10): + N = torch.randint(1, 100, size=(1,)).item() + logging.info(f"test_joiner: iter {i}, N={N}") + 
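+        # The torchscript joiner projects its raw inputs internally
+        # (project_input is True), while the ONNX joiner from export-onnx.py
+        # expects pre-projected tensors, so the projections are applied
+        # manually before calling run_joiner() below.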
encoder_out = torch.rand(N, encoder_dim) + decoder_out = torch.rand(N, decoder_dim) + + projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out) + projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out) + + torch_joiner_out = torch_joiner_model(encoder_out, decoder_out) + onnx_joiner_out = onnx_model.run_joiner( + projected_encoder_out, projected_decoder_out + ) + + assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( + (torch_joiner_out - onnx_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_encoder_model = torch.jit.load(args.jit_encoder_filename) + torch_decoder_model = torch.jit.load(args.jit_decoder_filename) + torch_joiner_model = torch.jit.load(args.jit_joiner_filename) + + onnx_model = OnnxModel( + encoder_model_filename=args.onnx_encoder_filename, + decoder_model_filename=args.onnx_decoder_filename, + joiner_model_filename=args.onnx_joiner_filename, + ) + + logging.info("Test encoder") + # When exporting the model to onnx, we have already put the encoder_proj + # inside the encoder. + test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model) + + logging.info("Test decoder") + # When exporting the model to onnx, we have already put the decoder_proj + # inside the decoder. + test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model) + + logging.info("Test joiner") + test_joiner(torch_joiner_model, onnx_model) + + logging.info("Finished checking ONNX models") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + torch.manual_seed(20230207) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 100644 index 000000000..71a418742 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1,231 @@ +from typing import Optional, Tuple + +import torch + + +class OnnxStreamingEncoder(torch.nn.Module): + """This class warps the streaming Zipformer to reduce the number of + state tensors for onnx. + https://github.com/k2-fsa/icefall/pull/831 + """ + + def __init__(self, encoder): + """ + Args: + encoder: An instance of Zipformer Class + """ + super().__init__() + self.model = encoder + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + len_cache: torch.tensor, + avg_cache: torch.tensor, + attn_cache: torch.tensor, + cnn_cache: torch.tensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. 
+ len_cache: + The cached numbers of past frames. + avg_cache: + The cached average tensors. + attn_cache: + The cached key tensors of the first attention modules. + The cached value tensors of the first attention modules. + The cached value tensors of the second attention modules. + cnn_cache: + The cached left contexts of the first convolution modules. + The cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 2 tensors: + + """ + num_encoder_layers = [] + encoder_attention_dims = [] + states = [] + for i, encoder in enumerate(self.model.encoders): + num_encoder_layers.append(encoder.num_layers) + encoder_attention_dims.append(encoder.attention_dim) + + len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B] + offset = 0 + for num_layer in num_encoder_layers: + states.append(len_cache[offset : offset + num_layer]) + offset += num_layer + + avg_cache = avg_cache.transpose(0, 1) # [15, B, 384] + offset = 0 + for num_layer in num_encoder_layers: + states.append(avg_cache[offset : offset + num_layer]) + offset += num_layer + + attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192] + left_context_len = attn_cache.shape[1] + offset = 0 + for i, num_layer in enumerate(num_encoder_layers): + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[offset : offset + num_layer, : left_context_len // ds] + ) + offset += num_layer + for i, num_layer in enumerate(num_encoder_layers): + encoder_attention_dim = encoder_attention_dims[i] + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[ + offset : offset + num_layer, + : left_context_len // ds, + :, + : encoder_attention_dim // 2, + ] + ) + offset += num_layer + for i, num_layer in enumerate(num_encoder_layers): + ds = self.model.zipformer_downsampling_factors[i] + states.append( + attn_cache[ + offset : offset + num_layer, + : left_context_len // ds, + :, + : encoder_attention_dim // 2, + ] + ) + offset += num_layer + + cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1] + offset = 0 + for num_layer in num_encoder_layers: + states.append(cnn_cache[offset : offset + num_layer]) + offset += num_layer + for num_layer in num_encoder_layers: + states.append(cnn_cache[offset : offset + num_layer]) + offset += num_layer + + encoder_out, encoder_out_lens, new_states = self.model.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose( + 0, 1 + ) # [B,15] + new_avg_cache = torch.cat( + states[self.model.num_encoders : 2 * self.model.num_encoders] + ).transpose( + 0, 1 + ) # [B,15,384] + new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose( + 0, 1 + ) # [B,2*15,384,cnn_kernel-1] + assert len(set(encoder_attention_dims)) == 1 + pad_tensors = [ + torch.nn.functional.pad( + tensor, + ( + 0, + encoder_attention_dims[0] - tensor.shape[-1], + 0, + 0, + 0, + left_context_len - tensor.shape[1], + 0, + 0, + ), + ) + for tensor in states[ + 2 * self.model.num_encoders : 5 * self.model.num_encoders + ] + ] + new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] + + return ( + encoder_out, + encoder_out_lens, + new_len_cache, + new_avg_cache, + new_attn_cache, + new_cnn_cache, + ) + + +class TritonOnnxDecoder(torch.nn.Module): + """This class warps the Decoder in decoder.py + to remove the scalar input "need_pad". + Triton currently doesn't support scalar input. 
+ https://github.com/triton-inference-server/server/issues/2333 + """ + + def __init__( + self, + decoder: torch.nn.Module, + ): + """ + Args: + decoder: A instance of Decoder + """ + super().__init__() + self.model = decoder + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + Returns: + Return a tensor of shape (N, U, decoder_dim). + """ + # False to not pad the input. Should be False during inference. + need_pad = False + return self.model(y, need_pad) + + +class TritonOnnxJoiner(torch.nn.Module): + """This class warps the Joiner in joiner.py + to remove the scalar input "project_input". + Triton currently doesn't support scalar input. + https://github.com/triton-inference-server/server/issues/2333 + "project_input" is set to True. + Triton solutions only need export joiner to a single joiner.onnx. + """ + + def __init__( + self, + joiner: torch.nn.Module, + ): + super().__init__() + self.model = joiner + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + # Apply input projections encoder_proj and decoder_proj. + project_input = True + return self.model(encoder_out, decoder_out, project_input) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 100755 index 000000000..8192e01fd --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1,512 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script loads ONNX models exported by ./export-onnx.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --decode-chunk-len 32 \ + --exp-dir $repo/exp/ + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +3. Run this file with the exported ONNX models + +./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. 
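+
+A sketch of what batch decoding would look like (B parallel streams): create
+the encoder states with OnnxModel.init_encoder_states(B), feed features of
+shape (B, T, 80) to run_encoder(), and keep one hypothesis per stream. See
+test_encoder() in ./onnx_check.py, which runs the same models with N > 1.
+
+With this export, each encoder call consumes T = 39 feature frames
+(decode_chunk_len = 32 plus 7 padding frames) and then advances by 32 frames,
+i.e. 0.32 seconds of audio, producing ((39 - 7) // 2 + 1) // 2 = 8 subsampled
+output frames per chunk.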
+""" + +import argparse +import logging +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + encoder_model_filename: str, + decoder_model_filename: str, + joiner_model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_encoder(encoder_model_filename) + self.init_decoder(decoder_model_filename) + self.init_joiner(joiner_model_filename) + + def init_encoder(self, encoder_model_filename: str): + self.encoder = ort.InferenceSession( + encoder_model_filename, + sess_options=self.session_opts, + ) + self.init_encoder_states() + + def init_encoder_states(self, batch_size: int = 1): + encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + + model_type = encoder_meta["model_type"] + assert model_type == "zipformer", model_type + + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + T = int(encoder_meta["T"]) + + num_encoder_layers = encoder_meta["num_encoder_layers"] + encoder_dims = encoder_meta["encoder_dims"] + attention_dims = encoder_meta["attention_dims"] + cnn_module_kernels = encoder_meta["cnn_module_kernels"] + left_context_len = encoder_meta["left_context_len"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + attention_dims = to_int_list(attention_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"attention_dims: {attention_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + + num_encoders = len(num_encoder_layers) + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + N = batch_size + + for i in range(num_encoders): + cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64)) + cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i])) + cached_key.append( + torch.zeros( + num_encoder_layers[i], left_context_len[i], N, attention_dims[i] + ) + ) + cached_val.append( + torch.zeros( + num_encoder_layers[i], + 
left_context_len[i], + N, + attention_dims[i] // 2, + ) + ) + cached_val2.append( + torch.zeros( + num_encoder_layers[i], + left_context_len[i], + N, + attention_dims[i] // 2, + ) + ) + cached_conv1.append( + torch.zeros( + num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 + ) + ) + cached_conv2.append( + torch.zeros( + num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 + ) + ) + + self.cached_len = cached_len + self.cached_avg = cached_avg + self.cached_key = cached_key + self.cached_val = cached_val + self.cached_val2 = cached_val2 + self.cached_conv1 = cached_conv1 + self.cached_conv2 = cached_conv2 + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def init_decoder(self, decoder_model_filename: str): + self.decoder = ort.InferenceSession( + decoder_model_filename, + sess_options=self.session_opts, + ) + + decoder_meta = self.decoder.get_modelmeta().custom_metadata_map + self.context_size = int(decoder_meta["context_size"]) + self.vocab_size = int(decoder_meta["vocab_size"]) + + logging.info(f"context_size: {self.context_size}") + logging.info(f"vocab_size: {self.vocab_size}") + + def init_joiner(self, joiner_model_filename: str): + self.joiner = ort.InferenceSession( + joiner_model_filename, + sess_options=self.session_opts, + ) + + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map + self.joiner_dim = int(joiner_meta["joiner_dim"]) + + logging.info(f"joiner_dim: {self.joiner_dim}") + + def _build_encoder_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + encoder_input = {"x": x.numpy()} + encoder_output = ["encoder_out"] + + def build_states_input(states: List[torch.Tensor], name: str): + for i, s in enumerate(states): + if isinstance(s, torch.Tensor): + encoder_input[f"{name}_{i}"] = s.numpy() + else: + encoder_input[f"{name}_{i}"] = s + + encoder_output.append(f"new_{name}_{i}") + + build_states_input(self.cached_len, "cached_len") + build_states_input(self.cached_avg, "cached_avg") + build_states_input(self.cached_key, "cached_key") + build_states_input(self.cached_val, "cached_val") + build_states_input(self.cached_val2, "cached_val2") + build_states_input(self.cached_conv1, "cached_conv1") + build_states_input(self.cached_conv2, "cached_conv2") + + return encoder_input, encoder_output + + def _update_states(self, states: List[np.ndarray]): + num_encoders = self.num_encoders + + self.cached_len = states[num_encoders * 0 : num_encoders * 1] + self.cached_avg = states[num_encoders * 1 : num_encoders * 2] + self.cached_key = states[num_encoders * 2 : num_encoders * 3] + self.cached_val = states[num_encoders * 3 : num_encoders * 4] + self.cached_val2 = states[num_encoders * 4 : num_encoders * 5] + self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6] + self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7] + + def run_encoder(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor of shape (N, T', joiner_dim) where + T' is usually equal to ((T-7)//2+1)//2 + """ + encoder_input, encoder_output_names = self._build_encoder_input_output(x) + out = self.encoder.run(encoder_output_names, encoder_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A 2-D tensor of shape (N, context_size) + Returns: + Return a 2-D tensor of shape (N, joiner_dim) + 
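+            Note: need_pad was fixed to False when the decoder was exported,
+            so decoder_input must always hold exactly context_size token IDs
+            per stream, e.g. hyp[-context_size:] as in greedy_search().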
""" + out = self.decoder.run( + [self.decoder.get_outputs()[0].name], + {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, + )[0] + + return torch.from_numpy(out) + + def run_joiner( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + out = self.joiner.run( + [self.joiner.get_outputs()[0].name], + { + self.joiner.get_inputs()[0].name: encoder_out.numpy(), + self.joiner.get_inputs()[1].name: decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(out) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: OnnxModel, + encoder_out: torch.Tensor, + context_size: int, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +) -> List[int]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (1, T, joiner_dim) + context_size: + The context size of the decoder model. + decoder_out: + Optional. Decoder output of the previous chunk. + hyp: + Decoding results for previous chunks. + Returns: + Return the decoded results so far. 
+ """ + + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor([hyp], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + else: + assert hyp is not None, hyp + + encoder_out = encoder_out.squeeze(0) + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t : t + 1] + joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + context_size = model.context_size + hyp = None + decoder_out = None + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + encoder_out = model.run_encoder(frames) + hyp, decoder_out = greedy_search( + model, + encoder_out, + context_size, + decoder_out, + hyp, + ) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py new file mode 100755 index 000000000..5a36b695f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -0,0 +1,419 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ + --tokens ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/tokens.txt \ + --encoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/joiner_jit_trace-pnnx.ncnn.bin \ + ./sherpa-ncnn-streaming-zipformer-en-2023-02-13/test_wavs/1089-134686-0001.wav + +You can find pretrained models at +- English: https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13 +- Bilingual (Chinese + English): https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13 +""" + +import argparse +import logging +from typing import List, Optional, Tuple + +import k2 +import ncnn +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + # Please change the parameters according to your model + self.num_encoder_layers = to_int_tuple("2,4,3,2,4") + self.encoder_dims = to_int_tuple("384,384,384,384,384") # also known as d_model + self.attention_dims = to_int_tuple("192,192,192,192,192") + self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2") + self.cnn_module_kernels = to_int_tuple("31,31,31,31,31") + + self.decode_chunk_size = 32 // 2 + num_left_chunks = 4 + self.left_context_length = self.decode_chunk_size * num_left_chunks # 64 + + self.chunk_length = self.decode_chunk_size * 2 + pad_length = 7 + self.T = self.chunk_length + pad_length + + def get_init_states(self) -> List[torch.Tensor]: + cached_len_list = [] + 
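+        # The streaming zipformer keeps 7 groups of caches, one entry per
+        # encoder stack: cached_len, cached_avg, cached_key, cached_val,
+        # cached_val2, cached_conv1 and cached_conv2. The loop below
+        # zero-initializes all of them, with shapes derived from
+        # num_encoder_layers, left_context_length // ds, attention_dims
+        # and cnn_module_kernels.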
cached_avg_list = []
+        cached_key_list = []
+        cached_val_list = []
+        cached_val2_list = []
+        cached_conv1_list = []
+        cached_conv2_list = []
+
+        for i in range(len(self.num_encoder_layers)):
+            num_layers = self.num_encoder_layers[i]
+            ds = self.zipformer_downsampling_factors[i]
+            attention_dim = self.attention_dims[i]
+            left_context_length = self.left_context_length // ds
+            encoder_dim = self.encoder_dims[i]
+            cnn_module_kernel = self.cnn_module_kernels[i]
+
+            cached_len_list.append(torch.zeros(num_layers))
+            cached_avg_list.append(torch.zeros(num_layers, encoder_dim))
+            cached_key_list.append(
+                torch.zeros(num_layers, left_context_length, attention_dim)
+            )
+            cached_val_list.append(
+                torch.zeros(num_layers, left_context_length, attention_dim // 2)
+            )
+            cached_val2_list.append(
+                torch.zeros(num_layers, left_context_length, attention_dim // 2)
+            )
+            cached_conv1_list.append(
+                torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
+            )
+            cached_conv2_list.append(
+                torch.zeros(num_layers, encoder_dim, cnn_module_kernel - 1)
+            )
+
+        states = (
+            cached_len_list
+            + cached_avg_list
+            + cached_key_list
+            + cached_val_list
+            + cached_val2_list
+            + cached_conv1_list
+            + cached_conv2_list
+        )
+
+        return states
+
+    def init_encoder(self, args):
+        encoder_net = ncnn.Net()
+        encoder_net.opt.use_packing_layout = False
+        encoder_net.opt.use_fp16_storage = False
+        encoder_net.opt.num_threads = 4
+
+        encoder_param = args.encoder_param_filename
+        encoder_model = args.encoder_bin_filename
+
+        encoder_net.load_param(encoder_param)
+        encoder_net.load_model(encoder_model)
+
+        self.encoder_net = encoder_net
+
+    def init_decoder(self, args):
+        decoder_param = args.decoder_param_filename
+        decoder_model = args.decoder_bin_filename
+
+        decoder_net = ncnn.Net()
+        decoder_net.opt.num_threads = 4
+
+        decoder_net.load_param(decoder_param)
+        decoder_net.load_model(decoder_model)
+
+        self.decoder_net = decoder_net
+
+    def init_joiner(self, args):
+        joiner_param = args.joiner_param_filename
+        joiner_model = args.joiner_bin_filename
+        joiner_net = ncnn.Net()
+        joiner_net.opt.num_threads = 4
+
+        joiner_net.load_param(joiner_param)
+        joiner_net.load_model(joiner_model)
+
+        self.joiner_net = joiner_net
+
+    def run_encoder(
+        self,
+        x: torch.Tensor,
+        states: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        Args:
+          x:
+            A tensor of shape (T, C)
+          states:
+            A list of tensors. len(states) == 7 * len(self.num_encoder_layers)
+        Returns:
+          Return a tuple containing:
+            - encoder_out, a tensor of shape (T, encoder_dim).
+ - next_states, a list of tensors containing the next states + """ + with self.encoder_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + + for i in range(len(states)): + name = f"in{i+1}" + ex.input(name, ncnn.Mat(states[i].squeeze().numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + + out_states: List[torch.Tensor] = [] + for i in range(len(states)): + name = f"out{i+1}" + ret, ncnn_out_state = ex.extract(name) + assert ret == 0, ret + ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy()) + + if i < len(self.num_encoder_layers): + # for cached_len, we need to discard the last dim + ncnn_out_state = ncnn_out_state.squeeze(1) + + out_states.append(ncnn_out_state) + + return encoder_out, out_states + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + else: + assert decoder_out.ndim == 1 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t] + + joiner_out = model.run_joiner(cur_encoder_out, decoder_out) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + states = model.get_init_states() + logging.info(f"number of states: {len(states)}") + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = model.T + offset = model.chunk_length + + chunk = int(1 * sample_rate) # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, states = model.run_encoder(frames, states) + hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + context_size = 2 + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(sound_file) + logging.info(text) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 7a349ecb2..c272ed641 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -409,18 +409,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + 
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -430,9 +426,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 2bdc882a5..c7a2a136d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1049,10 +1049,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py new file mode 100755 index 000000000..5437e961e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -0,0 +1,1265 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
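+#
+# Note: this file closely mirrors train.py in the same directory; the main
+# difference is that it builds the streaming Zipformer from zipformer2.py
+# with is_pnnx=True (see get_encoder_model below), presumably so that the
+# trained model can later be exported via pnnx/ncnn.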
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --full-libri 1 \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer2 import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. 
+ + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. 
When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. 
+ valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
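+            # For example: while cur_grad_scale < 1.0 the update below doubles
+            # it on every multiple of 100 batches; once it is in [1.0, 8.0) it
+            # only doubles on every multiple of 400 batches, and at 8.0 or
+            # above it is left to the scaler's own growth logic.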
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index 88beb38c1..a5c422959 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -44,7 +44,6 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.dist import get_rank from icefall.utils import make_pad_mask, subsequent_chunk_mask @@ -270,8 +269,7 @@ class Zipformer(EncoderInterface): dim_feedforward (int, int): feedforward dimension in 2 encoder stacks num_encoder_layers (int): number of encoder layers dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - vgg_frontend (bool): whether to use vgg frontend. 
+        cnn_module_kernels (int): Kernel size of convolution module
         warmup_batches (float): number of batches to warm up over
     """
 
@@ -311,6 +309,8 @@ class Zipformer(EncoderInterface):
         # Used in decoding
         self.decode_chunk_size = decode_chunk_size
 
+        self.left_context_len = self.decode_chunk_size * self.num_left_chunks
+
         # will be written to, see set_batch_count()
         self.batch_count = 0
         self.warmup_end = warmup_batches
@@ -330,7 +330,10 @@ class Zipformer(EncoderInterface):
         # each one will be ZipformerEncoder or DownsampledZipformerEncoder
         encoders = []
 
+        self.num_encoder_layers = num_encoder_layers
         self.num_encoders = len(encoder_dims)
+        self.attention_dims = attention_dim
+        self.cnn_module_kernels = cnn_module_kernels
         for i in range(self.num_encoders):
             encoder_layer = ZipformerEncoderLayer(
                 encoder_dims[i],
@@ -382,10 +385,10 @@ class Zipformer(EncoderInterface):
 
     def _init_skip_modules(self):
         """
-        If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
-        indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of
-        layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2,
-        we combine the outputs of layers 1 and 5.
+        If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer
+        indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of
+        layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2,
+        we combine the outputs of layers 1 and 4.
         """
         skip_layers = []
         skip_modules = []
@@ -421,13 +424,13 @@ class Zipformer(EncoderInterface):
         """
         In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
         randomized feature masks, one per encoder.
-        On e.g. 15% of frames, these masks will zero out all enocder dims larger than
+        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
         some supplied number, e.g. >256, so in effect on those frames we are using
-        a smaller encoer dim.
+        a smaller encoder dim.
 
         We generate the random masks at this level because we want the 2 masks to 'agree'
         all the way up the encoder stack. This will mean that the 1st mask will have
-        mask values repeated self.zipformer_subsampling_factor times.
+        mask values repeated self.zipformer_downsampling_factors times.
 
         Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
@@ -695,7 +698,7 @@ class Zipformer(EncoderInterface):
             num_layers = encoder.num_layers
             ds = self.zipformer_downsampling_factors[i]
 
-            len_avg = torch.zeros(num_layers, 1, dtype=torch.int32, device=device)
+            len_avg = torch.zeros(num_layers, 1, dtype=torch.int64, device=device)
             cached_len.append(len_avg)
 
             avg = torch.zeros(num_layers, 1, encoder.d_model, device=device)
@@ -1267,8 +1270,7 @@ class ZipformerEncoder(nn.Module):
         Shape:
             src: (S, N, E).
-            cached_len: (N,)
-              N is the batch size.
+            cached_len: (num_layers,)
             cached_avg: (num_layers, N, C).
               N is the batch size, C is the feature dimension.
             cached_key: (num_layers, left_context_len, N, K).
@@ -1284,8 +1286,8 @@ class ZipformerEncoder(nn.Module):
         Returns:
           A tuple of 8 tensors:
             - output tensor
-            - updated cached number of past frmaes.
-            - updated cached average of past frmaes.
+            - updated cached number of past frames.
+            - updated cached average of past frames.
             - updated cached key tensor of the first attention module.
             - updated cached value tensor of the first attention module.
             - updated cached value tensor of the second attention module.
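(Worked example for the new left_context_len attribute above: with the defaults used in this PR, decode_chunk_size = 32 // 2 = 16 and num_left_chunks = 4, so left_context_len = 16 * 4 = 64 frames.)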
@@ -1517,9 +1519,6 @@ class AttentionDownsample(torch.nn.Module):
     """
 
     def __init__(self, in_channels: int, out_channels: int, downsample: int):
-        """
-        Require out_channels > in_channels.
-        """
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5))
@@ -1687,8 +1686,8 @@ class RelPositionalEncoding(torch.nn.Module):
         if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
             self.pe = self.pe.to(dtype=x.dtype, device=x.device)
             return
-        # Suppose `i` means to the position of query vecotr and `j` means the
-        # position of key vector. We use position relative positions when keys
+        # Suppose `i` means the position of query vector and `j` means the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
new file mode 100644
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
+def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]:
+    """Stack list of zipformer states that correspond to separate utterances
+    into a single zipformer state, so that it can be used as an input for
+    zipformer when those utterances are formed into a batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+
+    Args:
+      state_list:
+        Each element in state_list corresponds to the internal state
+        of the zipformer model for a single utterance.
+        ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance.
+        ``states[i][0:num_encoders]`` is the cached numbers of past frames.
+        ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors.
+        ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules.
+        ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules.
+        ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules.
+        ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules.
+        ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules.
+
+    Returns:
+      A new state corresponding to a batch of utterances.
+      See the input argument of :func:`unstack_states` for the meaning
+      of the returned tensors.
+ """ + batch_size = len(state_list) + assert len(state_list[0]) % 7 == 0, len(state_list[0]) + num_encoders = len(state_list[0]) // 7 + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + # For cached_len + len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] + for i in range(num_encoders): + # len_avg: (num_layers, batch_size) + len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) + cached_len.append(len_avg) + + # For cached_avg + avg_list = [ + state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # avg: (num_layers, batch_size, D) + avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) + cached_avg.append(avg) + + # For cached_key + key_list = [ + state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # key: (num_layers, left_context_size, batch_size, D) + key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) + cached_key.append(key) + + # For cached_val + val_list = [ + state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val: (num_layers, left_context_size, batch_size, D) + val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) + cached_val.append(val) + + # For cached_val2 + val2_list = [ + state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # val2: (num_layers, left_context_size, batch_size, D) + val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) + cached_val2.append(val2) + + # For cached_conv1 + conv1_list = [ + state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv1: (num_layers, batch_size, D, kernel-1) + conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) + cached_conv1.append(conv1) + + # For cached_conv2 + conv2_list = [ + state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) + ] + for i in range(num_encoders): + # conv2: (num_layers, batch_size, D, kernel-1) + conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + A list of states. 
+      ``states[i]`` is a list of 7 * num_encoders elements of the i-th utterance.
+    """
+    assert len(states) % 7 == 0, len(states)
+    num_encoders = len(states) // 7
+    (
+        cached_len,
+        cached_avg,
+        cached_key,
+        cached_val,
+        cached_val2,
+        cached_conv1,
+        cached_conv2,
+    ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7))
+
+    batch_size = cached_len[0].shape[1]
+
+    len_list = [[] for _ in range(batch_size)]
+    for i in range(num_encoders):
+        # cached_len[i]: (num_layers, batch_size)
+        len_avg = cached_len[i].chunk(chunks=batch_size, dim=1)
+        for n in range(batch_size):
+            len_list[n].append(len_avg[n])
+
+    avg_list = [[] for _ in range(batch_size)]
+    for i in range(num_encoders):
+        # cached_avg[i]: (num_layers, batch_size, D)
+        avg = cached_avg[i].chunk(chunks=batch_size, dim=1)
+        for n in range(batch_size):
+            avg_list[n].append(avg[n])
+
+    key_list = [[] for _ in range(batch_size)]
+    for i in range(num_encoders):
+        # cached_key[i]: (num_layers, left_context, batch_size, D)
+        key = cached_key[i].chunk(chunks=batch_size, dim=2)
+        for n in range(batch_size):
+            key_list[n].append(key[n])
+
+    val_list = [[] for _ in range(batch_size)]
+    for i in range(num_encoders):
+        # cached_val[i]: (num_layers, left_context, batch_size, D)
+        val = cached_val[i].chunk(chunks=batch_size, dim=2)
+        for n in range(batch_size):
+            val_list[n].append(val[n])
+
+    val2_list = [[] for _ in range(batch_size)]
+    for i in range(num_encoders):
+        # cached_val2[i]: (num_layers, left_context, batch_size, D)
+        val2 = cached_val2[i].chunk(chunks=batch_size, dim=2)
+        for n in range(batch_size):
+            val2_list[n].append(val2[n])
+
+    conv1_list = [[] for _ in range(batch_size)]
+    for i in range(num_encoders):
+        # cached_conv1[i]: (num_layers, batch_size, D, kernel-1)
+        conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1)
+        for n in range(batch_size):
+            conv1_list[n].append(conv1[n])
+
+    conv2_list = [[] for _ in range(batch_size)]
+    for i in range(num_encoders):
+        # cached_conv2[i]: (num_layers, batch_size, D, kernel-1)
+        conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1)
+        for n in range(batch_size):
+            conv2_list[n].append(conv2[n])
+
+    state_list = [
+        (
+            len_list[i]
+            + avg_list[i]
+            + key_list[i]
+            + val_list[i]
+            + val2_list[i]
+            + conv1_list[i]
+            + conv2_list[i]
+        )
+        for i in range(batch_size)
+    ]
+    return state_list
+
+
+class Zipformer(EncoderInterface):
+    """
+    Args:
+        num_features (int): Number of input features
+        encoder_dims (int, int): embedding dimension of 2 encoder stacks
+        attention_dim (int, int): attention dimension of 2 encoder stacks
+        nhead (int, int): number of heads
+        feedforward_dim (int, int): feedforward dimension in 2 encoder stacks
+        num_encoder_layers (int, int): number of encoder layers in 2 encoder stacks
+        dropout (float): dropout rate
+        cnn_module_kernels (int, int): kernel size of convolution modules
+        warmup_batches (float): number of batches to warm up over
+        is_pnnx (bool): True if we are going to convert this model via pnnx.
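+
+    Example (an illustrative sketch; shapes follow the constructor defaults
+    below)::
+
+        >>> model = Zipformer(num_features=80)
+        >>> x = torch.rand(2, 100, 80)  # (batch_size, seq_len, feature_dim)
+        >>> x_lens = torch.full((2,), 100, dtype=torch.int64)
+        >>> encoder_out, out_lens = model(x, x_lens)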
+ """ + + def __init__( + self, + num_features: int, + output_downsampling_factor: int = 2, + encoder_dims: Tuple[int] = (384, 384), + attention_dim: Tuple[int] = (256, 256), + encoder_unmasked_dims: Tuple[int] = (256, 256), + zipformer_downsampling_factors: Tuple[int] = (2, 4), + nhead: Tuple[int] = (8, 8), + feedforward_dim: Tuple[int] = (1536, 2048), + num_encoder_layers: Tuple[int] = (12, 12), + dropout: float = 0.1, + cnn_module_kernels: Tuple[int] = (31, 31), + pos_dim: int = 4, + num_left_chunks: int = 4, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 50, + decode_chunk_size: int = 16, + warmup_batches: float = 4000.0, + is_pnnx: bool = False, + ) -> None: + super(Zipformer, self).__init__() + self.is_pnnx = is_pnnx + + self.num_features = num_features + assert 0 < encoder_dims[0] <= encoder_dims[1] + self.encoder_dims = encoder_dims + self.encoder_unmasked_dims = encoder_unmasked_dims + self.zipformer_downsampling_factors = zipformer_downsampling_factors + self.output_downsampling_factor = output_downsampling_factor + + self.num_left_chunks = num_left_chunks + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + + # Used in decoding + self.decode_chunk_size = decode_chunk_size + + self.left_context_len = self.decode_chunk_size * self.num_left_chunks + + # will be written to, see set_batch_count() + self.batch_count = 0 + self.warmup_end = warmup_batches + + for u, d in zip(encoder_unmasked_dims, encoder_dims): + assert u <= d, (u, d) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7)//2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7)//2 + # (2) embedding: num_features -> encoder_dims + self.encoder_embed = Conv2dSubsampling( + num_features, encoder_dims[0], dropout=dropout, is_pnnx=is_pnnx + ) + + # each one will be ZipformerEncoder or DownsampledZipformerEncoder + encoders = [] + + self.num_encoders = len(encoder_dims) + for i in range(self.num_encoders): + ds = zipformer_downsampling_factors[i] + encoder_layer = ZipformerEncoderLayer( + encoder_dims[i], + attention_dim[i], + nhead[i], + feedforward_dim[i], + dropout, + cnn_module_kernels[i], + pos_dim, + is_pnnx=self.is_pnnx, + left_context_len=self.left_context_len // ds, + x_size=self.decode_chunk_size // ds, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. 
+ encoder = ZipformerEncoder( + encoder_layer, + num_encoder_layers[i], + dropout, + warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), + is_pnnx=is_pnnx, + left_context_len=self.left_context_len // ds, + x_size=self.decode_chunk_size // ds, + ) + + if zipformer_downsampling_factors[i] != 1: + in_x_size = self.decode_chunk_size + encoder = DownsampledZipformerEncoder( + encoder, + input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + output_dim=encoder_dims[i], + downsample=zipformer_downsampling_factors[i], + is_pnnx=is_pnnx, + left_context_len=self.left_context_len // ds, + in_x_size=in_x_size, + ) + encoders.append(encoder) + self.encoders = nn.ModuleList(encoders) + + # initializes self.skip_layers and self.skip_modules + self._init_skip_modules() + + self.downsample_output = AttentionDownsample( + encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor, + is_pnnx=is_pnnx, + in_x_size=self.decode_chunk_size, + ) + + def _get_layer_skip_dropout_prob(self): + if not self.training: + return 0.0 + batch_count = self.batch_count + min_dropout_prob = 0.025 + + if batch_count > self.warmup_end: + return min_dropout_prob + else: + return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) + + def _init_skip_modules(self): + """ + If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer + indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of + layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, + we combine the outputs of layers 1 and 4. + """ + skip_layers = [] + skip_modules = [] + z = self.zipformer_downsampling_factors + for i in range(len(z)): + if i <= 1 or z[i - 1] <= z[i]: + skip_layers.append(None) + skip_modules.append(SimpleCombinerIdentity()) + else: + # TEMP + for j in range(i - 2, -1, -1): + if z[j] <= z[i] or j == 0: + # TEMP logging statement. + logging.info( + f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." + ) + skip_layers.append(j) + skip_modules.append( + SimpleCombiner( + self.encoder_dims[j], + self.encoder_dims[i - 1], + min_weight=(0.0, 0.25), + ) + ) + break + self.skip_layers = skip_layers + self.skip_modules = nn.ModuleList(skip_modules) + + def get_feature_masks(self, x: torch.Tensor) -> List[float]: + # Note: The actual return type is Union[List[float], List[Tensor]], + # but to make torch.jit.script() work, we use List[float] + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all encoder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoder dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_downsampling_factors times. 
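+
+        For example, with zipformer_downsampling_factors=(2, 4) both masks are
+        derived from one shared random mask drawn at the coarsest rate (factor
+        4); the mask for the first stack repeats each random decision twice
+        (upsample_factor = 4 // 2 = 2), while the second stack uses it as is.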
+
+        Args:
+           x: the embeddings (needed for the shape and dtype and device), of shape
+             (num_frames, batch_size, encoder_dims0)
+        """
+        num_encoders = len(self.encoder_dims)
+        if torch.jit.is_scripting() or not self.training:
+            return [1.0] * num_encoders
+
+        (num_frames0, batch_size, _encoder_dims0) = x.shape
+
+        assert self.encoder_dims[0] == _encoder_dims0, (
+            self.encoder_dims,
+            _encoder_dims0,
+        )
+
+        max_downsampling_factor = max(self.zipformer_downsampling_factors)
+
+        num_frames_max = num_frames0 + max_downsampling_factor - 1
+
+        feature_mask_dropout_prob = 0.15
+
+        # frame_mask_max shape: (num_frames_max, batch_size, 1)
+        frame_mask_max = (
+            torch.rand(num_frames_max, batch_size, 1, device=x.device)
+            > feature_mask_dropout_prob
+        ).to(x.dtype)
+
+        feature_masks = []
+        for i in range(num_encoders):
+            ds = self.zipformer_downsampling_factors[i]
+            upsample_factor = max_downsampling_factor // ds
+
+            frame_mask = (
+                frame_mask_max.unsqueeze(1)
+                .expand(num_frames_max, upsample_factor, batch_size, 1)
+                .reshape(num_frames_max * upsample_factor, batch_size, 1)
+            )
+            num_frames = (num_frames0 + ds - 1) // ds
+            frame_mask = frame_mask[:num_frames]
+            feature_mask = torch.ones(
+                num_frames,
+                batch_size,
+                self.encoder_dims[i],
+                dtype=x.dtype,
+                device=x.device,
+            )
+            u = self.encoder_unmasked_dims[i]
+            feature_mask[:, :, u:] *= frame_mask
+            feature_masks.append(feature_mask)
+
+        return feature_masks
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+        Returns:
+          Return a tuple containing 2 tensors:
+            - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1])
+            - lengths, a tensor of shape (batch_size,) containing the number
+              of frames in `embeddings` before padding.
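+
+        For example (a sketch): with seq_len = 100, output_seq_len is 23,
+        since Conv2dSubsampling reduces 100 frames to (100 - 7) // 2 = 46 and
+        the output downsampling (factor 2) then gives (46 + 1) // 2 = 23.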
+ """ + x = self.encoder_embed(x) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + lengths = (x_lens - 7) >> 1 + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) + mask = make_pad_mask(lengths) + + outputs = [] + feature_masks = self.get_feature_masks(x) + + if self.training: + # Training mode + max_ds = max(self.zipformer_downsampling_factors) + # Generate dynamic chunk-wise attention mask during training + max_len = x.size(0) // max_ds + short_chunk_size = self.short_chunk_size // max_ds + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * self.short_chunk_threshold): + # Full attention + chunk_size = x.size(0) + else: + # Chunk-wise attention + chunk_size = chunk_size % short_chunk_size + 1 + chunk_size *= max_ds + else: + chunk_size = self.decode_chunk_size + # Evaluation mode + for ds in self.zipformer_downsampling_factors: + assert chunk_size % ds == 0, (chunk_size, ds) + + attn_mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=self.num_left_chunks, + device=x.device, + ) + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + ds = self.zipformer_downsampling_factors[i] + k = self.skip_layers[i] + if isinstance(k, int): + layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() + if torch.jit.is_scripting(): + x = skip_module(outputs[k], x) + elif (not self.training) or random.random() > layer_skip_dropout_prob: + x = skip_module(outputs[k], x) + x = module( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[..., ::ds], + attn_mask=attn_mask[::ds, ::ds], + ) + outputs.append(x) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + lengths = (lengths + 1) >> 1 + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths + + def streaming_forward( + self, + x: torch.Tensor, + states: List[Tensor], + ) -> Tuple[Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + seq_len is the input chunk length. + states: + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. + + Returns: + Return a tuple containing 3 tensors: + - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + - updated states. 
+ """ + assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) + + cached_len = states[: self.num_encoders] + cached_avg = states[self.num_encoders : 2 * self.num_encoders] + cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] + cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] + cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] + cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] + cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] + + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + outputs = [] + new_cached_len = [] + new_cached_avg = [] + new_cached_key = [] + new_cached_val = [] + new_cached_val2 = [] + new_cached_conv1 = [] + new_cached_conv2 = [] + + for i, (module, skip_module) in enumerate( + zip(self.encoders, self.skip_modules) + ): + k = self.skip_layers[i] + if isinstance(k, int): + x = skip_module(outputs[k], x) + x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( + x, + cached_len=cached_len[i], + cached_avg=cached_avg[i], + cached_key=cached_key[i], + cached_val=cached_val[i], + cached_val2=cached_val2[i], + cached_conv1=cached_conv1[i], + cached_conv2=cached_conv2[i], + ) + + outputs.append(x) + # Update caches + new_cached_len.append(len_avg) + new_cached_avg.append(avg) + new_cached_key.append(key) + new_cached_val.append(val) + new_cached_val2.append(val2) + new_cached_conv1.append(conv1) + new_cached_conv2.append(conv2) + + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2, self.output_downsampling_factor + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = ( + new_cached_len + + new_cached_avg + + new_cached_key + + new_cached_val + + new_cached_val2 + + new_cached_conv1 + + new_cached_conv2 + ) + return x, new_states + + @torch.jit.export + def get_init_state( + self, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + A list of 7 * num_encoders elements: + ``states[0:num_encoders]`` is the cached numbers of past frames. + ``states[num_encoders:2*num_encoders]`` is the cached average tensors. + ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. + ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. + ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. + ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. + ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. 
+ """ + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + for i, encoder in enumerate(self.encoders): + num_layers = encoder.num_layers + ds = self.zipformer_downsampling_factors[i] + + len_avg = torch.zeros(num_layers, 1, device=device) + cached_len.append(len_avg) + + avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) + cached_avg.append(avg) + + key = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim, + device=device, + ) + cached_key.append(key) + + val = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val.append(val) + + val2 = torch.zeros( + num_layers, + self.left_context_len // ds, + 1, + encoder.attention_dim // 2, + device=device, + ) + cached_val2.append(val2) + + conv1 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv1.append(conv1) + + conv2 = torch.zeros( + num_layers, + 1, + encoder.d_model, + encoder.cnn_module_kernel - 1, + device=device, + ) + cached_conv2.append(conv2) + + states = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + return states + + +class ZipformerEncoderLayer(nn.Module): + """ + ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, + is_pnnx: bool = False, + left_context_len: int = 0, + x_size: int = 0, + ) -> None: + super(ZipformerEncoderLayer, self).__init__() + + self.d_model = d_model + self.attention_dim = attention_dim + self.cnn_module_kernel = cnn_module_kernel + + # will be written to, see set_batch_count() + self.batch_count = 0 + + self.self_attn = RelPositionMultiheadAttention( + d_model, + attention_dim, + nhead, + pos_dim, + dropout=0.0, + is_pnnx=is_pnnx, + left_context_len=left_context_len, + x_size=x_size, + ) + + self.pooling = PoolingModule(d_model) + + self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + + self.conv_module1 = ConvolutionModule( + d_model, cnn_module_kernel, is_pnnx=is_pnnx, x_size=x_size + ) + + self.conv_module2 = ConvolutionModule( + d_model, cnn_module_kernel, is_pnnx=is_pnnx, x_size=x_size + ) + + self.norm_final = BasicNorm(d_model) + + self.bypass_scale = nn.Parameter(torch.tensor(0.5)) + + # try to ensure the output is close to zero-mean (or at least, zero-median). 
+        self.balancer = ActivationBalancer(
+            d_model,
+            channel_dim=-1,
+            min_positive=0.45,
+            max_positive=0.55,
+            max_abs=6.0,
+        )
+        self.whiten = Whiten(
+            num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01
+        )
+
+    def get_bypass_scale(self):
+        if torch.jit.is_scripting() or not self.training:
+            return self.bypass_scale
+        if random.random() < 0.1:
+            # ensure we get grads if self.bypass_scale becomes out of range
+            return self.bypass_scale
+        # hardcode warmup period for bypass scale
+        warmup_period = 20000.0
+        initial_clamp_min = 0.75
+        final_clamp_min = 0.25
+        if self.batch_count > warmup_period:
+            clamp_min = final_clamp_min
+        else:
+            clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * (
+                initial_clamp_min - final_clamp_min
+            )
+        return self.bypass_scale.clamp(min=clamp_min, max=1.0)
+
+    def get_dynamic_dropout_rate(self):
+        # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this
+        # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable
+        # at the beginning, by making the network focus on the feedforward modules.
+        if torch.jit.is_scripting() or not self.training:
+            return 0.0
+        warmup_period = 2000.0
+        initial_dropout_rate = 0.2
+        final_dropout_rate = 0.0
+        if self.batch_count > warmup_period:
+            return final_dropout_rate
+        else:
+            return initial_dropout_rate - (
+                initial_dropout_rate - final_dropout_rate
+            ) * (self.batch_count / warmup_period)
+
+    def forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+            src: the sequence to the encoder layer (required).
+            pos_emb: Positional embedding tensor (required).
+            attn_mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*S-1, E)
+            attn_mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number
+        """
+        src_orig = src
+
+        # macaron style feed forward module
+        src = src + self.feed_forward1(src)
+
+        # dropout rate for submodules that interact with time.
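+        # E.g. with the schedule in get_dynamic_dropout_rate() above, the rate
+        # decays linearly from 0.2 at batch 0 (0.1 at batch 1000) to 0.0 from
+        # batch 2000 onwards.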
+        dynamic_dropout = self.get_dynamic_dropout_rate()
+
+        # pooling module
+        if torch.jit.is_scripting():
+            src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask)
+        elif random.random() >= dynamic_dropout:
+            src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask)
+
+        if torch.jit.is_scripting():
+            src_att, attn_weights = self.self_attn(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=attn_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
+            src = src + src_att
+
+            src = src + self.conv_module1(
+                src, src_key_padding_mask=src_key_padding_mask
+            )
+
+            src = src + self.feed_forward2(src)
+
+            src = src + self.self_attn.forward2(src, attn_weights)
+
+            src = src + self.conv_module2(
+                src, src_key_padding_mask=src_key_padding_mask
+            )
+        else:
+            use_self_attn = random.random() >= dynamic_dropout
+            if use_self_attn:
+                src_att, attn_weights = self.self_attn(
+                    src,
+                    pos_emb=pos_emb,
+                    attn_mask=attn_mask,
+                    key_padding_mask=src_key_padding_mask,
+                )
+                src = src + src_att
+
+            if random.random() >= dynamic_dropout:
+                src = src + self.conv_module1(
+                    src, src_key_padding_mask=src_key_padding_mask
+                )
+
+            src = src + self.feed_forward2(src)
+
+            if use_self_attn:
+                src = src + self.self_attn.forward2(src, attn_weights)
+
+            if random.random() >= dynamic_dropout:
+                src = src + self.conv_module2(
+                    src, src_key_padding_mask=src_key_padding_mask
+                )
+
+        src = src + self.feed_forward3(src)
+
+        src = self.norm_final(self.balancer(src))
+
+        delta = src - src_orig
+
+        src = src_orig + delta * self.get_bypass_scale()
+
+        return self.whiten(src)
+
+    def streaming_forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        cached_len: Tensor,
+        cached_avg: Tensor,
+        cached_key: Tensor,
+        cached_val: Tensor,
+        cached_val2: Tensor,
+        cached_conv1: Tensor,
+        cached_conv2: Tensor,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+        """
+        Pass the input through the encoder layer.
+
+        Args:
+            src: the sequence to the encoder layer (required).
+            pos_emb: Positional embedding tensor (required).
+            cached_len: processed number of past frames.
+            cached_avg: cached average of past frames.
+            cached_key: cached key tensor of left context for the first attention module.
+            cached_val: cached value tensor of left context for the first attention module.
+            cached_val2: cached value tensor of left context for the second attention module.
+            cached_conv1: cached left context for the first convolution module.
+            cached_conv2: cached left context for the second convolution module.
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, left_context_len+2*S-1, E)
+            cached_len: (N,)
+                N is the batch size.
+            cached_avg: (N, C).
+                N is the batch size, C is the feature dimension.
+            cached_key: (left_context_len, N, K).
+                N is the batch size, K is the key dimension.
+            cached_val: (left_context_len, N, V).
+                N is the batch size, V is the value dimension.
+            cached_val2: (left_context_len, N, V).
+                N is the batch size, V is the value dimension.
+            cached_conv1: (N, C, kernel_size-1).
+                N is the batch size, C is the convolution channels.
+            cached_conv2: (N, C, kernel_size-1).
+                N is the batch size, C is the convolution channels.
+ """ + src_orig = src + + # macaron style feed forward module + src = src + self.feed_forward1(src) + + src_pool, cached_len, cached_avg = self.pooling.streaming_forward( + src, + cached_len=cached_len, + cached_avg=cached_avg, + ) + src = src + src_pool + + ( + src_attn, + attn_weights, + cached_key, + cached_val, + ) = self.self_attn.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + cached_val=cached_val, + ) + + src = src + src_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + ) + + src = src + src_conv + + src = src + self.feed_forward2(src) + + src_attn, cached_val2 = self.self_attn.streaming_forward2( + src, + attn_weights, + cached_val=cached_val2, + ) + src = src + src_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm_final(self.balancer(src)) + + delta = src - src_orig + + src = src_orig + delta * self.bypass_scale + + return ( + src, + cached_len, + cached_avg, + cached_key, + cached_val, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class ZipformerStateSelect(nn.Module): + """ncnn does not support selecting along batch index. + This class provides a workaround for it. We + need to change pnnx accordingly. + """ + + def __init__(self, i: int): + super().__init__() + self.i = i + + def forward(self, x: torch.Tensor): + return x[self.i] + + +class ZipformerEncoder(nn.Module): + r"""ZipformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ZipformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) + >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + is_pnnx: bool = False, + x_size: int = 0, + left_context_len: int = 0, + ) -> None: + super().__init__() + # will be written to, see set_batch_count() Note: in inference time this + # may be zero but should be treated as large, we can check if + # self.training is true. + self.batch_count = 0 + self.warmup_begin = warmup_begin + self.warmup_end = warmup_end + # module_seed is for when we need a random number that is unique to the module but + # shared across jobs. It's used to randomly select how many layers to drop, + # so that we can keep this consistent across worker tasks (for efficiency). 
+        self.module_seed = torch.randint(0, 1000, ()).item()
+        self.left_context_len = left_context_len
+
+        self.encoder_pos = RelPositionalEncoding(
+            encoder_layer.d_model,
+            dropout,
+            is_pnnx=is_pnnx,
+            x_size=x_size,
+            left_context_len=left_context_len,
+        )
+
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        state_select_list = []
+        for i in range(num_layers):
+            state_select_list.append(ZipformerStateSelect(i))
+        self.state_select_list = nn.ModuleList(state_select_list)
+
+        self.d_model = encoder_layer.d_model
+        self.attention_dim = encoder_layer.attention_dim
+        self.cnn_module_kernel = encoder_layer.cnn_module_kernel
+
+        assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end)
+
+        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
+        cur_begin = warmup_begin
+        for i in range(num_layers):
+            self.layers[i].warmup_begin = cur_begin
+            cur_begin += delta
+            self.layers[i].warmup_end = cur_begin
+
+    def get_layers_to_drop(self, rnd_seed: int):
+        ans = set()
+        if not self.training:
+            return ans
+
+        batch_count = self.batch_count
+        num_layers = len(self.layers)
+
+        def get_layerdrop_prob(layer: int) -> float:
+            layer_warmup_begin = self.layers[layer].warmup_begin
+            layer_warmup_end = self.layers[layer].warmup_end
+
+            initial_layerdrop_prob = 0.5
+            final_layerdrop_prob = 0.05
+
+            if batch_count == 0:
+                # As a special case, if batch_count == 0, return 0 (drop no
+                # layers). This is rather ugly, I'm afraid; it is intended to
+                # enable our scan_pessimistic_batches_for_oom() code to work correctly
+                # so if we are going to get OOM it will happen early.
+                # also search for 'batch_count' with quotes in this file to see
+                # how we initialize the warmup count to a random number between
+                # 0 and 10.
+                return 0.0
+            elif batch_count < layer_warmup_begin:
+                return initial_layerdrop_prob
+            elif batch_count > layer_warmup_end:
+                return final_layerdrop_prob
+            else:
+                # linearly interpolate
+                t = (batch_count - layer_warmup_begin) / (
+                    layer_warmup_end - layer_warmup_begin
+                )
+                assert 0.0 <= t < 1.001, t
+                return initial_layerdrop_prob + t * (
+                    final_layerdrop_prob - initial_layerdrop_prob
+                )
+
+        shared_rng = random.Random(batch_count + self.module_seed)
+        independent_rng = random.Random(rnd_seed)
+
+        layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)]
+        tot = sum(layerdrop_probs)
+        # Instead of drawing the samples independently, we first randomly decide
+        # how many layers to drop out, using the same random number generator between
+        # jobs so that all jobs drop out the same number (this is for speed).
+        # Then we use an approximate approach to drop out the individual layers
+        # with their specified probs while reaching this exact target.
+        num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))
+
+        layers = list(range(num_layers))
+        independent_rng.shuffle(layers)
+
+        # go through the shuffled layers until we get the required number of samples.
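+        # (itertools.cycle below revisits the shuffled order in case a single
+        # pass selects fewer than num_to_drop layers.)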
+        if num_to_drop > 0:
+            for layer in itertools.cycle(layers):
+                if independent_rng.random() < layerdrop_probs[layer]:
+                    ans.add(layer)
+                if len(ans) == num_to_drop:
+                    break
+        if shared_rng.random() < 0.005 or __name__ == "__main__":
+            logging.info(
+                f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
+                f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}"
+            )
+        return ans
+
+    def forward(
+        self,
+        src: Tensor,
+        # Note: The type of feature_mask should be Union[float, Tensor],
+        # but to make torch.jit.script() work, we use `float` here
+        feature_mask: float = 1.0,
+        attn_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+                by at every layer.
+            attn_mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            pos_emb: (N, 2*S-1, E)
+            attn_mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number
+
+        Returns: output tensor of shape (S, N, E)
+        """
+        pos_emb = self.encoder_pos(src)
+        output = src
+
+        if torch.jit.is_scripting():
+            layers_to_drop = []
+        else:
+            rnd_seed = src.numel() + random.randint(0, 1000)
+            layers_to_drop = self.get_layers_to_drop(rnd_seed)
+
+        output = output * feature_mask
+
+        for i, mod in enumerate(self.layers):
+            if not torch.jit.is_scripting():
+                if i in layers_to_drop:
+                    continue
+            output = mod(
+                output,
+                pos_emb,
+                attn_mask=attn_mask,
+                src_key_padding_mask=src_key_padding_mask,
+            )
+
+            output = output * feature_mask
+
+        return output
+
+    @torch.jit.export
+    def streaming_forward(
+        self,
+        src: Tensor,
+        cached_len: Tensor,
+        cached_avg: Tensor,
+        cached_key: Tensor,
+        cached_val: Tensor,
+        cached_val2: Tensor,
+        cached_conv1: Tensor,
+        cached_conv2: Tensor,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            cached_len: number of past frames.
+            cached_avg: cached average of past frames.
+            cached_key: cached key tensor for first attention module.
+            cached_val: cached value tensor for first attention module.
+            cached_val2: cached value tensor for second attention module.
+            cached_conv1: cached left contexts for the first convolution module.
+            cached_conv2: cached left contexts for the second convolution module.
+
+        Shape:
+            src: (S, N, E).
+            cached_len: (num_layers, N)
+            cached_avg: (num_layers, N, C).
+                N is the batch size, C is the feature dimension.
+            cached_key: (num_layers, left_context_len, N, K).
+                N is the batch size, K is the key dimension.
+            cached_val: (num_layers, left_context_len, N, V).
+                N is the batch size, V is the value dimension.
+            cached_val2: (num_layers, left_context_len, N, V).
+                N is the batch size, V is the value dimension.
+            cached_conv1: (num_layers, N, C, kernel_size-1).
+                N is the batch size, C is the convolution channels.
+            cached_conv2: (num_layers, N, C, kernel_size-1).
+                N is the batch size, C is the convolution channels.
+
+        Returns: A tuple of 8 tensors:
+            - output tensor
+            - updated cached number of past frames.
+            - updated cached average of past frames.
+            - updated cached key tensor of the first attention module.
+            - updated cached value tensor of the first attention module.
+            - updated cached value tensor of the second attention module.
+            - updated cached left contexts of the first convolution module.
+            - updated cached left contexts of the second convolution module.
+        """
+        assert cached_len.size(0) == self.num_layers, (
+            cached_len.size(0),
+            self.num_layers,
+        )
+        assert cached_avg.size(0) == self.num_layers, (
+            cached_avg.size(0),
+            self.num_layers,
+        )
+        assert cached_key.size(0) == self.num_layers, (
+            cached_key.size(0),
+            self.num_layers,
+        )
+        assert cached_val.size(0) == self.num_layers, (
+            cached_val.size(0),
+            self.num_layers,
+        )
+        assert cached_val2.size(0) == self.num_layers, (
+            cached_val2.size(0),
+            self.num_layers,
+        )
+        assert cached_conv1.size(0) == self.num_layers, (
+            cached_conv1.size(0),
+            self.num_layers,
+        )
+        assert cached_conv2.size(0) == self.num_layers, (
+            cached_conv2.size(0),
+            self.num_layers,
+        )
+
+        assert self.left_context_len == cached_key.shape[1], (
+            self.left_context_len,
+            cached_key.shape[1],
+        )
+
+        left_context_len = self.left_context_len
+        pos_emb = self.encoder_pos(src, left_context_len)
+
+        output = src
+
+        new_cached_len = []
+        new_cached_avg = []
+        new_cached_key = []
+        new_cached_val = []
+        new_cached_val2 = []
+        new_cached_conv1 = []
+        new_cached_conv2 = []
+        for i, (mod, state_select) in enumerate(
+            zip(self.layers, self.state_select_list)
+        ):
+            output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward(
+                output,
+                pos_emb,
+                cached_len=cached_len[i],
+                cached_avg=cached_avg[i],
+                cached_key=cached_key[i],
+                cached_val=cached_val[i],
+                cached_val2=cached_val2[i],
+                cached_conv1=state_select(cached_conv1),
+                cached_conv2=state_select(cached_conv2),
+            )
+            # Update caches
+            new_cached_len.append(len_avg)
+            new_cached_avg.append(avg)
+            new_cached_key.append(key)
+            new_cached_val.append(val)
+            new_cached_val2.append(val2)
+            new_cached_conv1.append(conv1)
+            new_cached_conv2.append(conv2)
+
+        return (
+            output,
+            torch.stack(new_cached_len, dim=0),
+            torch.stack(new_cached_avg, dim=0),
+            torch.stack(new_cached_key, dim=0),
+            torch.stack(new_cached_val, dim=0),
+            torch.stack(new_cached_val2, dim=0),
+            torch.stack(new_cached_conv1, dim=0),
+            torch.stack(new_cached_conv2, dim=0),
+        )
+
+
+class DownsampledZipformerEncoder(nn.Module):
+    r"""
+    DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
+    after downsampling (via AttentionDownsample), and then upsampled again at the output,
+    and combined with the original input, so that the output has the same shape as the
+    input.
+ """ + + def __init__( + self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int, + is_pnnx: bool = False, + left_context_len: int = 0, + in_x_size: int = 0, + ): + super(DownsampledZipformerEncoder, self).__init__() + self.downsample_factor = downsample + self.downsample = AttentionDownsample( + input_dim, output_dim, downsample, is_pnnx=is_pnnx, in_x_size=in_x_size + ) + self.encoder = encoder + self.num_layers = encoder.num_layers + self.d_model = encoder.d_model + self.attention_dim = encoder.attention_dim + self.cnn_module_kernel = encoder.cnn_module_kernel + self.upsample = SimpleUpsample(output_dim, downsample) + self.out_combiner = SimpleCombiner( + input_dim, output_dim, min_weight=(0.0, 0.25) + ) + self.in_x_size = in_x_size + + def forward( + self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer. feature_mask is expected to be already downsampled by + self.downsample_factor. + attn_mask: attention mask (optional). Should be downsampled already. + src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. + + Shape: + src: (S, N, E). + attn_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + Returns: output of shape (S, N, F) where F is the number of output features + (output_dim to constructor) + """ + src_orig = src + src = self.downsample(src) + + src = self.encoder( + src, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + cached_len: Tensor, + cached_avg: Tensor, + cached_key: Tensor, + cached_val: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required). + cached_avg: cached average value of past frames. + cached_len: length of past frames. + cached_key: cached key tensor for the first attention module. + cached_val: cached value tensor for the first attention module. + cached_val2: cached value tensor for the second attention module. + cached_conv1: cached left context for the first convolution module. + cached_conv2: cached left context for the second convolution module. + + Shape: + src: (S, N, E). + cached_len: (N,) + N is the batch size. + cached_avg: (num_layers, N, C). + N is the batch size, C is the feature dimension. + cached_key: (num_layers, left_context_len, N, K). + N is the batch size, K is the key dimension. + cached_val: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_val2: (num_layers, left_context_len, N, V). + N is the batch size, V is the key dimension. + cached_conv1: (num_layers, N, C, kernel_size-1). 
+                N is the batch size, C is the convolution channels.
+            cached_conv2: (num_layers, N, C, kernel_size-1).
+                N is the batch size, C is the convolution channels.
+        Returns: A tuple of 8 tensors: the output of shape (S, N, F), where F
+            is the number of output features (output_dim to constructor),
+            followed by the updated caches.
+        """
+        assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size)
+
+        src_orig = src
+
+        src = self.downsample(src)
+
+        (
+            src,
+            cached_len,
+            cached_avg,
+            cached_key,
+            cached_val,
+            cached_val2,
+            cached_conv1,
+            cached_conv2,
+        ) = self.encoder.streaming_forward(
+            src,
+            cached_len=cached_len,
+            cached_avg=cached_avg,
+            cached_key=cached_key,
+            cached_val=cached_val,
+            cached_val2=cached_val2,
+            cached_conv1=cached_conv1,
+            cached_conv2=cached_conv2,
+        )
+
+        src = self.upsample(src)
+
+        if src.shape[0] != self.in_x_size:
+            # remove any extra frames that are not a multiple of downsample_factor
+            src = src[: self.in_x_size]
+
+        return (
+            self.out_combiner(src_orig, src),
+            cached_len,
+            cached_avg,
+            cached_key,
+            cached_val,
+            cached_val2,
+            cached_conv1,
+            cached_conv2,
+        )
+
+
+class AttentionDownsampleUnsqueeze(torch.nn.Module):
+    """We apply this operation only in PyTorch; it is discarded when
+    exporting to ncnn.
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.unsqueeze(1)
+
+
+class AttentionDownsample(torch.nn.Module):
+    """
+    Does downsampling with attention, by weighted sum, and a projection.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        downsample: int,
+        is_pnnx: bool = False,
+        in_x_size: int = 0,
+    ):
+        super(AttentionDownsample, self).__init__()
+
+        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5))
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.is_pnnx = is_pnnx
+        self.in_x_size = in_x_size
+
+        self.unsqueeze = AttentionDownsampleUnsqueeze()
+
+        # fill in the extra dimensions with a projection of the input
+        if out_channels > in_channels:
+            self.extra_proj = nn.Linear(
+                in_channels * downsample, out_channels - in_channels, bias=False
+            )
+        else:
+            self.extra_proj = None
+        self.downsample = downsample
+
+        self.d_seq_len = (in_x_size + downsample - 1) // downsample
+
+    def forward(self, src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, in_channels)
+        Returns a tensor of shape
+           ((seq_len+downsample-1)//downsample, batch_size, out_channels)
+        """
+        assert src.shape[0] == self.in_x_size, (
+            src.shape[0],
+            self.in_x_size,
+            src.shape,
+            type(src),
+        )
+        assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels)
+        if not self.is_pnnx:
+            (seq_len, batch_size, in_channels) = src.shape
+        else:
+            seq_len = self.in_x_size
+            batch_size = 1
+            in_channels = self.in_channels
+
+        ds = self.downsample
+        d_seq_len = self.d_seq_len
+
+        # Pad to an exact multiple of self.downsample
+        if seq_len != d_seq_len * ds:
+            assert self.is_pnnx is False, "TODO(fangjun): Handle it!"
+            # right-pad src, repeating the last element.
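+            # E.g. seq_len=7 with downsample=2 gives d_seq_len=4, so one copy
+            # of the final frame is appended to reach 4 * 2 = 8 frames.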
+ pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) + + if not self.is_pnnx: + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + src = src.permute(0, 2, 1, 3).reshape( + d_seq_len, batch_size, ds * in_channels + ) + + if self.extra_proj is not None: + ans2 = self.extra_proj(src) + ans = torch.cat((ans, ans2), dim=2) + else: + src = src.reshape(d_seq_len, ds, in_channels) + scores = (src * self.query).sum(dim=-1, keepdim=True) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + + weights = scores.softmax(dim=1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + assert ( + self.extra_proj is None + ), "The code for it being not None is not tested" + # ans = ans.unsqueeze(1) + ans = self.unsqueeze(ans) + # Note: In ncnn, we ignore self.unsqueeze + # so ans in ncnn is still a 2-D tensor, e.g., (8, 384) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + self.upsample = upsample + self.num_channels = num_channels + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.bias.shape[0] + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src + self.bias.unsqueeze(1) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class SimpleCombinerIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + return src1 + + +class SimpleCombiner(torch.nn.Module): + """ + A very simple way of combining 2 vectors of 2 different dims, via a + learned weighted combination in the shared part of the dim. + Args: + dim1: the dimension of the first input, e.g. 256 + dim2: the dimension of the second input, e.g. 384. + The output will have the same dimension as dim2. 
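+
+    For example (a sketch): with dim1=256 and dim2=384, src1 is scaled by the
+    learned weight, zero-padded from 256 to 384 channels, and added to src2
+    scaled by (1 - weight); see forward() below.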
+ """ + + def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + super(SimpleCombiner, self).__init__() + assert dim2 >= dim1, (dim2, dim1) + self.weight1 = nn.Parameter(torch.zeros(())) + self.min_weight = min_weight + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + """ + src1: (*, dim1) + src2: (*, dim2) + + Returns: a tensor of shape (*, dim2) + """ + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) + + weight1 = self.weight1 + if not torch.jit.is_scripting(): + if ( + self.training + and random.random() < 0.25 + and self.min_weight != (0.0, 0.0) + ): + weight1 = weight1.clamp( + min=self.min_weight[0], max=1.0 - self.min_weight[1] + ) + + src1 = src1 * weight1 + src2 = src2 * (1.0 - weight1) + + assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1) + assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2) + + src1_dim = self.dim1 + src2_dim = self.dim2 + + if src1_dim != src2_dim: + if src1_dim < src2_dim: + src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) + else: + src1 = src1[:src2_dim] + + return src1 + src2 + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + is_pnnx: bool = False, + x_size: int = 0, + left_context_len: int = 0, + ) -> None: + """Construct a PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(dropout_rate) + self.is_pnnx = is_pnnx + self.x_size = x_size + self.left_context_len = left_context_len + self.pe = None + if is_pnnx: + x_size_left = x_size + left_context_len + self.extend_pe(torch.tensor(0.0).expand(x_size_left)) + self.pe = self.pe[:, :-left_context_len] + assert self.pe.size(1) == x_size + left_context_len - 1 + x_size, ( + self.pe.size(1), + x_size, + left_context_len, + x_size, + self.pe.shape, + ) + else: + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + x_size_left = x.size(0) + left_context_len + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_left * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tensor: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). 
+ + """ + if self.is_pnnx: + assert self.x_size == x.size(0), (self.x_size, x.size(0)) + assert self.left_context_len == left_context_len, ( + self.left_context_len, + left_context_len, + ) + return self.pe + + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_left + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(0), + ] + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionPermute(nn.Module): + """ncnn does not support permuatation relating to the batch axis 0. + This is a workaround for exporting to ncnn via PNNX. + """ + + def __init__(self, kind: int): + super().__init__() + self.kind = kind + assert self.kind in (2, 3), self.kind + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.kind == 2: + return x.permute(1, 0, 2) + elif self.kind == 3: + return x.permute(1, 2, 0) + else: + assert False, f"Unsupported kind {self.kind}" + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: total dimension of the model. + attention_dim: dimension in the attention module, may be less or more than embed_dim + but must be a multiple of num_heads. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + attention_dim: int, + num_heads: int, + pos_dim: int, + dropout: float = 0.0, + is_pnnx: bool = False, + left_context_len: int = 0, + x_size: int = 0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.attention_dim = attention_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = attention_dim // num_heads + self.pos_dim = pos_dim + assert self.head_dim % 2 == 0, self.head_dim + assert self.head_dim * num_heads == attention_dim, ( + self.head_dim, + num_heads, + attention_dim, + ) + + self.is_pnnx = is_pnnx + + self.my_permute_pqv = RelPositionMultiheadAttentionPermute(kind=2) + self.my_permute_k_pos = RelPositionMultiheadAttentionPermute(kind=3) + self.left_context_len = left_context_len + self.x_size = x_size + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5, dividing it between the query and key. + in_proj_dim = ( + 2 * attention_dim + + attention_dim // 2 # query (attention_dim,), key (attention_dim,) + + pos_dim * num_heads # value (attention_dim // 2,) + ) # positional encoding query (pos_dim * num_heads, ) + + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 + ) + + # self.whiten_values is applied on the values in forward(); + # it just copies the keys but prevents low-rank distribution by modifying grads. + self.whiten_values = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # linear transformation for positional encoding. 
+        self.linear_pos = ScaledLinear(
+            embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05
+        )
+
+        # the following are for diagnostics only, see --print-diagnostics option.
+        # they only copy their inputs.
+        self.copy_pos_query = Identity()
+        self.copy_query = Identity()
+
+        self.out_proj = ScaledLinear(
+            attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
+        )
+
+        self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False)
+        self.out_proj2 = ScaledLinear(
+            attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
+        )
+        # self.whiten_values2 is applied on the values in forward2()
+        self.whiten_values2 = Whiten(
+            num_groups=num_heads,
+            whitening_limit=2.0,
+            prob=(0.025, 0.25),
+            grad_scale=0.025,
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        pos_emb: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        r"""
+        Args:
+            x: input to be projected to query, key, value
+            pos_emb: Positional embedding tensor
+            key_padding_mask: if provided, specified padding elements in the key will
+                be ignored by the attention. When given a binary mask and a value is True,
+                the corresponding value on the attention layer will be ignored. When given
+                a byte mask and a value is non-zero, the corresponding value on the attention
+                layer will be ignored
+            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+        Shape:
+            - Inputs:
+            - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension.
+            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+              If a ByteTensor is provided, the non-zero positions will be ignored while the positions
+              with zero values will be unchanged. If a BoolTensor is provided, the positions with the
+              value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
+            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+              is provided, it will be added to the attention weight.
+
+            - Returns: (attn_output, attn_weights)
+
+            - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size,
+              E is the embedding dimension.
+            - attn_weights: :math:`(N * H, S, S)` where N is the batch size, H is the num-heads
+              and S is the sequence length.
+ """ + x, weights = self.multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + return x, weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r""" + Args: + x: input to be projected to query, key, value + pos_emb: Positional embedding tensor + + Shape: + - Inputs: + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. + - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. + + - Returns: (attn_output, attn_weights, cached_key, cached_val) + + - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads + and S is the sequence length. + - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of + left context + - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of + """ + ( + x, + weights, + cached_key, + cached_val, + ) = self.streaming_multi_head_attention_forward( + self.in_proj(x), + self.linear_pos(pos_emb), + self.attention_dim, + self.num_heads, + self.out_proj.weight, + self.out_proj.bias, + cached_key=cached_key, + cached_val=cached_val, + ) + return x, weights, cached_key, cached_val + + def multi_head_attention_forward( + self, + x_proj: Tensor, + pos: Tensor, + attention_dim: int, + num_heads: int, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x_proj: the projected input, to be split into query, key, value. + pos: head-specific biases arising from the positional embeddings. + attention_dim: dimension inside attention mechanism + num_heads: parallel attention heads. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is + the attention dimension. Will be split into (query, key, value, pos). + - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence + length, N is the batch size, and A is the attention dim. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_weights: :math:`(N * H, S, S)` where N is the batch size, + H is the num-heads, S is the sequence length. + """ + + seq_len, bsz, _ = x_proj.size() + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[..., 0:attention_dim] + k = x_proj[..., attention_dim : 2 * attention_dim] + value_dim = attention_dim // 2 + v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[..., 2 * attention_dim + value_dim :] + + k = self.whiten_keys(k) # does nothing in the forward pass. + v = self.whiten_values(v) # does nothing in the forward pass. + q = self.copy_query(q) # for diagnostics only, does nothing. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, seq_len, seq_len]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + seq_len, + seq_len, + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+            )
+            key_padding_mask = key_padding_mask.to(torch.bool)
+
+        q = q.reshape(seq_len, bsz, num_heads, head_dim)
+        p = p.reshape(seq_len, bsz, num_heads, pos_dim)
+        k = k.reshape(seq_len, bsz, num_heads, head_dim)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+                key_padding_mask.size(0), bsz
+            )
+            assert key_padding_mask.size(1) == seq_len, "{} == {}".format(
+                key_padding_mask.size(1), seq_len
+            )
+
+        q = q.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim)
+        p = p.permute(1, 2, 0, 3)  # (batch, head, time1, pos_dim)
+        k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
+
+        seq_len2 = 2 * seq_len - 1
+        pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1)
+        # pos shape now: (batch, head, pos_dim, seq_len2)
+
+        # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2)
+        # [where seq_len2 represents relative position.]
+        pos_weights = torch.matmul(p, pos)
+        # the following .as_strided() expression converts the last axis of pos_weights from relative
+        # to absolute position.  I don't know whether I might have got the time-offsets backwards or
+        # not, but let this code define which way round it is supposed to be.
+        pos_weights = pos_weights.as_strided(
+            (bsz, num_heads, seq_len, seq_len),
+            (
+                pos_weights.stride(0),
+                pos_weights.stride(1),
+                pos_weights.stride(2) - pos_weights.stride(3),
+                pos_weights.stride(3),
+            ),
+            storage_offset=pos_weights.stride(3) * (seq_len - 1),
+        )
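+        # Worked example of the re-indexing above (illustrative numbers, not part
+        # of the computation): with seq_len = 3, row t of pos_weights holds scores
+        # for relative offsets -2..2 at columns 0..4.  The strided view starts
+        # row t at storage offset (seq_len - 1) - t, so row 0 reads columns
+        # 2, 3, 4 (offsets 0, +1, +2), row 1 reads columns 1, 2, 3, and row 2
+        # reads columns 0, 1, 2, i.e. each row is now indexed by the absolute
+        # key position j, whose score sits at relative offset j - t.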
+
+        # caution: they are really scores at this point.
+        attn_output_weights = torch.matmul(q, k) + pos_weights
+
+        if not torch.jit.is_scripting():
+            if training and random.random() < 0.1:
+                # This is a harder way of limiting the attention scores to not be too large.
+                # It incurs a penalty if any of them has an absolute value greater than 25.0.
+                # this should be outside the normal range of the attention scores.  We use
+                # this mechanism instead of, say, a limit on entropy, because once the entropy
+                # gets very small gradients through the softmax can become very small, and
+                # some mechanisms like that become ineffective.
+                attn_output_weights = penalize_abs_values_gt(
+                    attn_output_weights, limit=25.0, penalty=1.0e-04
+                )
+
+        # attn_output_weights: (batch, head, time1, time2)
+        attn_output_weights = attn_output_weights.view(
+            bsz * num_heads, seq_len, seq_len
+        )
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_output_weights = attn_output_weights.masked_fill(
+                    attn_mask, float("-inf")
+                )
+            else:
+                attn_output_weights = attn_output_weights + attn_mask
+
+        if key_padding_mask is not None:
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, seq_len, seq_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, seq_len, seq_len
+            )
+
+        # Using this version of softmax, defined in scaling.py,
+        # should save a little of the memory used in backprop by, if
+        # we are in automatic mixed precision mode (amp) == autocast,
+        # only storing the half-precision output for backprop purposes.
+        attn_output_weights = softmax(attn_output_weights, dim=-1)
+
+        # If we are using a chunk-wise attention mask and setting a limited
+        # num_left_chunks, the attention may only see the padding values, which
+        # will also be masked out by `key_padding_mask`.  In such circumstances,
+        # the whole column of `attn_output_weights` will be `-inf`
+        # (i.e. be `nan` after softmax).  So we fill `0.0` at the masking
+        # positions to avoid invalid loss value below.
+        if (
+            attn_mask is not None
+            and attn_mask.dtype == torch.bool
+            and key_padding_mask is not None
+        ):
+            if attn_mask.size(0) != 1:
+                attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
+
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, seq_len, seq_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, seq_len, seq_len
+            )
+
+        attn_output_weights = nn.functional.dropout(
+            attn_output_weights, p=dropout_p, training=training
+        )
+
+        attn_output = torch.bmm(attn_output_weights, v)
+        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
+        attn_output = (
+            attn_output.transpose(0, 1)
+            .contiguous()
+            .view(seq_len, bsz, attention_dim // 2)
+        )
+        attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
+
+        return attn_output, attn_output_weights
+
+    def streaming_multi_head_attention_forward(
+        self,
+        x_proj: Tensor,
+        pos: Tensor,
+        attention_dim: int,
+        num_heads: int,
+        out_proj_weight: Tensor,
+        out_proj_bias: Tensor,
+        cached_key: Tensor,
+        cached_val: Tensor,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        r"""
+        Args:
+            x_proj: the projected input, to be split into query, key, value.
+            pos: head-specific biases arising from the positional embeddings.
+            attention_dim: dimension inside attention mechanism
+            num_heads: parallel attention heads.
+            out_proj_weight, out_proj_bias: the output projection weight and bias.
+            cached_key: cached attention key tensor of left context.
+            cached_val: cached attention value tensor of left context.
+
+        Shape:
+            Inputs:
+            - x_proj: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is
+              the attention dimension. Will be split into (query, key, value, pos).
+            - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence
+              length, N is the batch size, and A is the attention dim.
+
+            Outputs:
+            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+              E is the embedding dimension.
+            - attn_weights: :math:`(N * H, S, S)` where N is the batch size,
+              H is the num-heads, S is the sequence length.
+            - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context.
+            - cached_val: :math:`(left_context_len, N, V)`, updated cached attention value tensor of left context.
+        """
+        if not self.is_pnnx:
+            seq_len, bsz, _ = x_proj.size()
+            assert seq_len == self.x_size, (seq_len, self.x_size)
+        else:
+            seq_len = self.x_size
+            bsz = 1
+
+        head_dim = attention_dim // num_heads
+        pos_dim = self.pos_dim  # positional-encoding dim per head
+        assert (
+            head_dim * num_heads == attention_dim
+        ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}"
+
+        # self-attention
+        q = x_proj[:, :, 0:attention_dim]  # (x_size, N, attention_dim)
+        # return q, q, q, q
+        k = x_proj[:, :, attention_dim : 2 * attention_dim]
+        # k is (x_size, N, attention_dim)
+        value_dim = attention_dim // 2
+        v = x_proj[:, :, 2 * attention_dim : 2 * attention_dim + value_dim]
+        # v is (x_size, N, attention_dim//2)
+
+        # p is the position-encoding query, its dimension is num_heads * pos_dim.
+        p = x_proj[:, :, 2 * attention_dim + value_dim :]
+        # p is (x_size, N, pos_dim * num_heads)
+
+        if not self.is_pnnx:
+            left_context_len = cached_key.shape[0]
+        else:
+            assert cached_key.shape[0] == self.left_context_len, (
+                cached_key.shape,
+                self.left_context_len,
+            )
+            left_context_len = self.left_context_len
+
+        assert left_context_len > 0, left_context_len
+        assert cached_key.shape[0] == cached_val.shape[0], (
+            cached_key.shape,
+            cached_val.shape,
+        )
+        # Note: We need to fix the Concat in ncnn
+        # cached_key is (1, 64, 192) in ncnn
+        # k is (16, 192) in ncnn
+        # Pad cached left contexts
+        k = torch.cat([cached_key, k], dim=0)
+        # (left_context_len + x_size, N, attention_dim)
+
+        v = torch.cat([cached_val, v], dim=0)
+        # v: (left_context_len + x_size, N, attention_dim//2)
+        # Update cached left contexts
+        if not self.is_pnnx:
+            cached_key = k[-left_context_len:, ...]
+            cached_val = v[-left_context_len:, ...]
+        else:
+            cached_key = k[self.x_size :]
+            cached_val = v[self.x_size :]
+            assert cached_key.shape[0] == left_context_len, (
+                cached_key.shape,
+                left_context_len,
+            )
+            assert cached_val.shape[0] == left_context_len, (
+                cached_val.shape,
+                left_context_len,
+            )
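+        # Worked example of the cache update (illustrative numbers only): with
+        # left_context_len = 64 and a 16-frame chunk, k grows from (16, N, K)
+        # to (80, N, K) after the concat, and the slice keeps the last 64
+        # frames as the cache for the next chunk, so every chunk attends to
+        # itself plus the 64 most recent cached frames.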
+ """ + if not self.is_pnnx: + seq_len, bsz, _ = x_proj.size() + assert seq_len == self.x_size, (seq_len, self.x_size) + else: + seq_len = self.x_size + bsz = 1 + + head_dim = attention_dim // num_heads + pos_dim = self.pos_dim # positional-encoding dim per head + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + + # self-attention + q = x_proj[:, :, 0:attention_dim] # (x_size, N, attention_dim) + # return q, q, q, q + k = x_proj[:, :, attention_dim : 2 * attention_dim] + # k is (x_size, N, attention_dim) + value_dim = attention_dim // 2 + v = x_proj[:, :, 2 * attention_dim : 2 * attention_dim + value_dim] + # v is (x_size, 0, attention_dim//2) + + # p is the position-encoding query, its dimension is num_heads*pos_dim.. + p = x_proj[:, :, 2 * attention_dim + value_dim :] + # p is (x_size, N, pos_dim * num_heads) + + if not self.is_pnnx: + left_context_len = cached_key.shape[0] + else: + assert cached_key.shape[0] == self.left_context_len, ( + cached_key.shape, + self.left_context_len, + ) + left_context_len = self.left_context_len + + assert left_context_len > 0, left_context_len + assert cached_key.shape[0] == cached_val.shape[0], ( + cached_key.shape, + cached_val.shape, + ) + # Note: We need to fix the Concat in ncnn + # cached_key is (1, 64, 192) in ncnn + # k is (16, 192) in ncnn + # Pad cached left contexts + k = torch.cat([cached_key, k], dim=0) + # (left_context_len + x_size, N, attention_dim) + + v = torch.cat([cached_val, v], dim=0) + # v: (left_context_len + x_size, N, attention_dim//2) + # Update cached left contexts + if not self.is_pnnx: + cached_key = k[-left_context_len:, ...] + cached_val = v[-left_context_len:, ...] + else: + cached_key = k[self.x_size :] + cached_val = v[self.x_size :] + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape, + left_context_len, + ) + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape, + left_context_len, + ) + + if not self.is_pnnx: + # The length of key and value + kv_len = k.shape[0] + else: + kv_len = left_context_len + self.x_size + assert kv_len == k.shape[0], (kv_len, k.shape) + + if not self.is_pnnx: + q = q.reshape(seq_len, bsz, num_heads, head_dim) + p = p.reshape(seq_len, bsz, num_heads, pos_dim) + k = k.reshape(kv_len, bsz, num_heads, head_dim) + + v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + # v is (bsz * num_heads, kv_len, head_dim//2) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) + p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + + seq_len2 = 2 * seq_len - 1 + left_context_len + pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) + # pos shape now: (batch, head, pos_dim, seq_len2) + else: + q = q.reshape(seq_len, num_heads, head_dim) + p = p.reshape(seq_len, num_heads, pos_dim) + k = k.reshape(kv_len, num_heads, head_dim) + # v = v.reshape(kv_len, num_heads, head_dim // 2).permute(1, 0, 2) + v = v.reshape(kv_len, num_heads, head_dim // 2) + v = self.my_permute_pqv(v) + # v is (num_heads, kv_len, head_dim//2) e.g., (8, 80, 12) + + # q = q.permute(1, 0, 2) # (head, time1, head_dim) + # p = p.permute(1, 0, 2) # (head, time1, pos_dim) + # k = k.permute(1, 2, 0) # (head, d_k, time2) + + q = self.my_permute_pqv(q) # (head, time1, head_dim), e.g., (8, 16, 24) + p = self.my_permute_pqv(p) # (head, time1, pos_dim), e.g., (8, 16, 4) + k = self.my_permute_k_pos(k) # (head, d_k, 
time2) e.g., (8, 24, 80) + + seq_len2 = 2 * seq_len - 1 + left_context_len + # pos = pos.reshape(seq_len2, num_heads, pos_dim).permute(1, 2, 0) + # pos shape now: (head, pos_dim, seq_len2) + + pos = pos.reshape(seq_len2, num_heads, pos_dim) + pos = self.my_permute_k_pos( + pos + ) # (head, pos_dim, seq_len2), e.g, (8, 4, 95) + + # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) ,e.g., (1, 8, 16, 95) + # [where seq_len2 represents relative position.] + pos_weights = torch.matmul(p, pos) + + # the following .as_strided() expression converts the last axis of pos_weights from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + + if not self.is_pnnx: + pos_weights = pos_weights.as_strided( + (bsz, num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2) - pos_weights.stride(3), + pos_weights.stride(3), + ), + storage_offset=pos_weights.stride(3) * (seq_len - 1), + ) + else: + pos_weights = pos_weights.as_strided( + (num_heads, seq_len, kv_len), + ( + pos_weights.stride(0), + pos_weights.stride(1) - pos_weights.stride(2), + pos_weights.stride(2), + ), + storage_offset=pos_weights.stride(2) * (seq_len - 1), + ) + + # caution: they are really scores at this point. + attn_output_weights = torch.matmul(q, k) + pos_weights + + # attn_output_weights: (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) + + # Using this version of softmax, defined in scaling.py, + # should save a little of the memory used in backprop by, if + # we are in automatic mixed precision mode (amp) == autocast, + # only storing the half-precision output for backprop purposes. + attn_output_weights = softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] + # (8, 16, 12) + + if not self.is_pnnx: + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, attention_dim // 2) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + else: + attn_output = self.my_permute_pqv(attn_output) # (1, 0, 2) + attn_output = attn_output.reshape(seq_len, bsz, attention_dim // 2) + # We have changed InnerProduct in ncnn to treat + # (seq_len, bsz, attention_dim//2) as + # (seq_len, attention_dim//2) + + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + return attn_output, attn_output_weights, cached_key, cached_val + + def forward2( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + Returns: + output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + """ + num_heads = self.num_heads + (seq_len, bsz, embed_dim) = x.shape + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + v = self.whiten_values2(v) # does nothing in the forward pass. 
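+        # Note on the design: forward2() deliberately re-uses the normalized
+        # attn_weights computed by forward(), so the (query, key, positional)
+        # score computation is shared by two value streams; only the value
+        # projections differ (in_proj2 / out_proj2), costing one extra batched
+        # matmul below instead of a second full attention pass.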
+ v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not torch.jit.is_scripting(): + if random.random() < 0.001 or __name__ == "__main__": + self._print_attn_stats(attn_weights, attn_output) + + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + # returned value is of shape (seq_len, bsz, embed_dim), like x. + return self.out_proj2(attn_output) + + def streaming_forward2( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Second forward function, where we re-use the attn_weights returned by the first forward function + but with different input. + Args: + x: input, of shape (seq_len, batch_size, embed_dim) + attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) + cached_val: cached attention value tensor of left context. + Returns: + - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) + - updated cached attention value tensor of left context. + """ + num_heads = self.num_heads + + assert x.shape[0] == self.x_size, (x.shape[0], self.x_size) + assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim) + + if not self.is_pnnx: + (seq_len, bsz, embed_dim) = x.shape + else: + seq_len = self.x_size + bsz = 1 + embed_dim = self.embed_dim + + head_dim = self.attention_dim // num_heads + # v: (tgt_len, bsz, embed_dim // 2) + v = self.in_proj2(x) + + assert cached_val.shape[0] == self.left_context_len, ( + cached_val.shape[0], + self.left_context_len, + ) + + left_context_len = self.left_context_len + assert left_context_len > 0, left_context_len + v = torch.cat([cached_val, v], dim=0) + cached_val = v[-left_context_len:] + + seq_len2 = left_context_len + seq_len + if not self.is_pnnx: + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) + else: + v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2) + # v = v.permute(1, 0, 2) + v = self.my_permute_pqv(v) + + # now v: (bsz * num_heads, seq_len, head_dim // 2) + attn_output = torch.bmm(attn_weights, v) + + if not self.is_pnnx: + # attn_output: (bsz * num_heads, seq_len, head_dim) + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, bsz, self.attention_dim // 2) + ) + else: + attn_output = self.my_permute_pqv(attn_output) # (1, 0, 2) + attn_output = attn_output.reshape(seq_len, bsz, self.attention_dim // 2) + # We have changed InnerProduct in ncnn to ignore bsz + # when invoking self.out_proj2(attn_output) + + # returned value is of shape (seq_len, bsz, embed_dim), like x. 
+        return self.out_proj2(attn_output), cached_val
+
+    def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor):
+        # attn_weights: (batch_size * num_heads, seq_len, seq_len)
+        # attn_output: (bsz * num_heads, seq_len, head_dim)
+        (n, seq_len, head_dim) = attn_output.shape
+        num_heads = self.num_heads
+        bsz = n // num_heads
+
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                attn_weights = attn_weights.to(torch.float32)
+                attn_output = attn_output.to(torch.float32)
+                attn_weights_entropy = (
+                    -((attn_weights + 1.0e-20).log() * attn_weights)
+                    .sum(dim=-1)
+                    .reshape(bsz, num_heads, seq_len)
+                    .mean(dim=(0, 2))
+                )
+                attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim)
+                attn_output = attn_output.permute(1, 0, 2, 3).reshape(
+                    num_heads, bsz * seq_len, head_dim
+                )
+                attn_output_mean = attn_output.mean(dim=1, keepdim=True)
+                attn_output = attn_output - attn_output_mean
+                attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (
+                    bsz * seq_len
+                )
+                # attn_covar: (num_heads, head_dim, head_dim)
+                # eigs, _ = torch.symeig(attn_covar)
+                # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}")
+
+                attn_covar = _diag(attn_covar).mean(dim=1)  # (num_heads,)
+                embed_dim = self.in_proj2.weight.shape[1]
+                in_proj_covar = (
+                    self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2
+                ).mean(dim=(1, 2))
+                out_proj_covar = (
+                    self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2
+                ).mean(dim=(0, 2))
+                logging.info(
+                    f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}"
+                )
+
+
+class FeedforwardModule(nn.Module):
+    """Feedforward module in Zipformer model."""
+
+    def __init__(self, d_model: int, feedforward_dim: int, dropout: float):
+        super(FeedforwardModule, self).__init__()
+        self.in_proj = nn.Linear(d_model, feedforward_dim)
+        self.balancer = ActivationBalancer(
+            feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25
+        )
+        self.activation = DoubleSwish()
+        self.dropout = nn.Dropout(dropout)
+        self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01)
+
+    def forward(self, x: Tensor):
+        x = self.in_proj(x)
+        x = self.balancer(x)
+        x = self.activation(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Zipformer model.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+        bias (bool): Whether to use bias in conv layers (default=True).
+
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        is_pnnx: bool = False,
+        x_size: int = 0,
+    ) -> None:
+        """Construct a ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0, kernel_size
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
+        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+        # between 50 and 100 for different channels.  This will cause very peaky and
+        # sparse derivatives for the sigmoid gating function, which will tend to make
+        # the loss function not learn effectively.  (For most layers the average absolute values
+        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of pointwise_conv1 is around 0.35 to 0.45 for different
+        # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.)  The idea is that if we
+        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+        # it will be in a better position to start learning something, i.e. to latch onto
+        # the correct range.
+        self.deriv_balancer1 = ActivationBalancer(
+            2 * channels,
+            channel_dim=1,
+            max_abs=10.0,
+            min_positive=0.05,
+            max_positive=1.0,
+        )
+
+        # Will pad cached left context
+        self.lorder = kernel_size - 1
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.deriv_balancer2 = ActivationBalancer(
+            channels,
+            channel_dim=1,
+            min_positive=0.05,
+            max_positive=1.0,
+            max_abs=20.0,
+        )
+
+        self.activation = DoubleSwish()
+
+        self.pointwise_conv2 = ScaledConv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+            initial_scale=0.05,
+        )
+
+        self.is_pnnx = is_pnnx
+        self.x_size = x_size
+
+    def forward(
+        self,
+        x: Tensor,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional):
+                (batch, #time), contains bool in masked positions.
+
+        Returns:
+            - Output tensor (#time, batch, channels).
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+
+        x = self.deriv_balancer1(x)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        # 1D Depthwise Conv
+        # Make depthwise_conv causal by
+        # manually padding self.lorder zeros to the left
+        x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+        x = self.depthwise_conv(x)
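+        # Causality check with illustrative numbers (not a fixed configuration):
+        # for kernel_size = 31, lorder = 30, so padding 30 zeros on the left
+        # means output frame t depends only on input frames t-30..t, with no
+        # future context.  This is what lets streaming_forward() below swap the
+        # zero padding for real cached frames without changing the math.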
+
+        x = self.deriv_balancer2(x)
+        x = self.activation(x)
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        return x.permute(2, 0, 1)
+
+    def streaming_forward(
+        self,
+        x: Tensor,
+        cache: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """Compute convolution module in streaming mode.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            cache: Cached left context for depthwise_conv, with shape of
+                (batch, channels, #kernel_size-1). Only used in real streaming decoding.
+
+        Returns:
+            A tuple of 2 tensors:
+            - Output tensor (#time, batch, channels).
+            - New cached left context, with shape of (batch, channels, #kernel_size-1).
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
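+        # Streaming note: the zero left-padding that forward() adds before
+        # depthwise_conv is replaced here by `cache`, the last
+        # lorder = kernel_size - 1 frames of the previous chunk, which are
+        # concatenated below before the depthwise convolution.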
+ + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( + cache.shape, + (x.size(0), x.size(1), self.lorder), + ) + x = torch.cat([cache, x], dim=2) + + cache = x[:, :, self.x_size :] + + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1), cache + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = (T-3)//2 - 2 == (T-7)//2 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + dropout: float = 0.1, + is_pnnx: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, (T-7)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer2_channels: + Number of channels in layer2 + layer3_channels: + Number of channels in layer3 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 7, in_channels + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=(0, 1), # (time, freq) + ), + # After this layer (N, 1, T, C) -> (N, layer1_channels, T-2, C) + ActivationBalancer(layer1_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + padding=0, + ), + # After this layer (N, layer1_channels, T-2, C) -> (N, layer2_channels, ((T-2) - 3)//2+1, (C-3)//2+1) + # i.e., (N, layer2_channels, (T-5)//2+1, (C-3)//2+1) + # i.e., (N, layer2_channels, (T-3)//2, (C-1)//2) + ActivationBalancer(layer2_channels, channel_dim=1), + DoubleSwish(), + nn.Conv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=(1, 2), # (time, freq) + ), + # After this layer, (N, layer2_channels, (T-3)//2, (C-1)//2) + # -> + # (N, layer3_channels, (T-3)//2-2, ((C-1)//2 - 3)//2 + 1) + # (N, layer3_channels, (T-7)//2, (C-3)//4) + ActivationBalancer(layer3_channels, channel_dim=1), + DoubleSwish(), + ) + out_height = (((in_channels - 1) // 2) - 1) // 2 + self.out = ScaledLinear(out_height * layer3_channels, out_channels) + self.dropout = nn.Dropout(dropout) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). 
+
+        Returns:
+          Return a tensor of shape (N, (T-7)//2, odim)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        x = self.conv(x)
+
+        if torch.jit.is_tracing() and self.is_pnnx:
+            x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
+            x = self.out(x)
+        else:
+            # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2)
+            b, c, t, f = x.size()
+            x = self.out(x.transpose(1, 2).reshape(b, t, c * f))
+        # Now x is of shape (N, (T-7)//2, odim)
+        x = self.dropout(x)
+        return x
+
+
+def _test_zipformer_main():
+    feature_dim = 50
+    batch_size = 5
+    seq_len = 47
+
+    c = Zipformer(
+        num_features=feature_dim,
+        encoder_dims=(64, 96),
+        encoder_unmasked_dims=(48, 64),
+        nhead=(4, 4),
+        decode_chunk_size=4,
+    )
+    # Just make sure the forward pass runs.
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+    )
+    assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f[0].shape[1])
+    f[0].sum().backward()
+    c.eval()
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+    )
+    f  # to remove flake8 warnings
+
+
+def _test_conv2d_subsampling():
+    num_features = 80
+    encoder_dims = 384
+    dropout = 0.1
+    encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout)
+    for i in range(20, 40):
+        x = torch.rand(2, i, num_features)
+        y = encoder_embed(x)
+        assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1])
+
+
+def _test_pooling_module():
+    N, S, C = 2, 12, 32
+    chunk_len = 4
+    m = PoolingModule(d_model=C)
+
+    # test chunk-wise forward with padding_mask
+    x = torch.randn(S, N, C)
+    y = m(x)
+    cached_len = torch.zeros(N, dtype=torch.int32)
+    cached_avg = torch.zeros(N, C)
+    for i in range(S // chunk_len):
+        start = i * chunk_len
+        end = start + chunk_len
+        x_chunk = x[start:end]
+        y_chunk, cached_len, cached_avg = m.streaming_forward(
+            x_chunk,
+            cached_len=cached_len,
+            cached_avg=cached_avg,
+        )
+        assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end])
+
+
+def _test_state_stack_unstack():
+    m = Zipformer(
+        num_features=80,
+        encoder_dims=(64, 96),
+        encoder_unmasked_dims=(48, 64),
+        nhead=(4, 4),
+        zipformer_downsampling_factors=(4, 8),
+        num_left_chunks=2,
+        decode_chunk_size=8,
+    )
+    s1 = m.get_init_state()
+    s2 = m.get_init_state()
+    states = stack_states([s1, s2])
+    new_s1, new_s2 = unstack_states(states)
+    for i in range(m.num_encoders * 7):
+        for x, y in zip(s1[i], new_s1[i]):
+            assert torch.equal(x, y)
+        for x, y in zip(s2[i], new_s2[i]):
+            assert torch.equal(x, y)
+
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    _test_zipformer_main()
+    _test_conv2d_subsampling()
+    _test_pooling_module()
+    _test_state_stack_unstack()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
index e61367134..7b651a632 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
@@ -594,18 +594,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -615,9 +611,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index abe249c7b..b0abad5ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -1154,10 +1154,10 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index 553b7d092..bb55ed6bb 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -30,6 +30,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer +from lhotse.cut import Cut from lhotse.utils import fix_random_seed from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP @@ -645,8 +646,23 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) scan_pessimistic_batches_for_oom( model=model, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md index 94d4ed6a3..b1e01a218 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md @@ -1,4 +1,4 @@ Please visit - + for how to run this recipe. 
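The two train.py hunks above replace the eager concatenation of the three LibriSpeech training subsets with a single train_all_shuf_cuts() call when params.full_libri is set. A minimal sketch of why a pre-shuffled combined manifest helps with lazily opened manifests; the paths and the "train-all-shuf" name below follow the data module's naming scheme and are assumptions, not taken from this patch:

    from lhotse import CutSet

    # Lazy concatenation chains the subsets one after another, so early batches
    # are drawn almost entirely from train-clean-100:
    sequential = (
        CutSet.from_jsonl_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
        + CutSet.from_jsonl_lazy("data/fbank/librispeech_cuts_train-clean-360.jsonl.gz")
        + CutSet.from_jsonl_lazy("data/fbank/librispeech_cuts_train-other-500.jsonl.gz")
    )

    # A manifest shuffled once offline interleaves all 960 hours from the first
    # batch, which is what train_all_shuf_cuts() is expected to read:
    all_shuf = CutSet.from_jsonl_lazy(
        "data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz"
    )
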
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index 804713a20..8d379d1fa 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -325,18 +325,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -346,9 +342,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 9511ca6d7..806b68f40 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -322,18 +322,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -343,9 +339,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 643238f1b..42125e19f 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -379,18 +379,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -400,9 +396,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py index 9a6363629..b05fe2a4d 100755 --- a/egs/librispeech/ASR/transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/transducer_stateless2/decode.py @@ -379,18 +379,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -400,9 +396,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index 56ad558c6..5570b30ae 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -380,18 +380,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -401,9 +397,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py index 7d0ea78bb..33c0bf199 100755 --- a/egs/librispeech/ASR/zipformer_mmi/decode.py +++ b/egs/librispeech/ASR/zipformer_mmi/decode.py @@ -471,18 +471,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats(f, f"{test_set_name}-{key}", results) test_set_wers[key] = wer @@ -490,9 +486,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py index 1463f8f67..72338bade 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py @@ -410,17 +410,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -430,9 +426,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index 219c96d60..4434aae62 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -391,18 +391,14 @@ def save_results( test_set_wers = dict() test_set_cers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- wers_filename = ( - params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt" - ) + wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" with open(wers_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -413,9 +409,7 @@ def save_results( results_char = [] for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - cers_filename = ( - params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" - ) + cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" with open(cers_filename, "w") as f: cer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -426,9 +420,7 @@ def save_results( test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) for key in test_set_wers: diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh index d9938fa63..c5d498d74 100755 --- a/egs/tal_csasr/ASR/prepare.sh +++ b/egs/tal_csasr/ASR/prepare.sh @@ -12,9 +12,12 @@ stop_stage=100 # directories and files. If not, they will be downloaded # by this script automatically. # -# - $dl_dir/tal_csasr +# - $dl_dir/TALCS_corpus # You can find three directories:train_set, dev_set, and test_set. # You can get it from https://ai.100tal.com/dataset +# - dev_set +# - test_set +# - train_set # # - $dl_dir/musan # This directory contains the following directories downloaded from @@ -44,7 +47,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" # Before you run this script, you must get the TAL_CSASR dataset # from https://ai.100tal.com/dataset - mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr + if [ ! -d $dl_dir/tal_csasr/TALCS_corpus ]; then + mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr + fi # If you have pre-downloaded it to /path/to/TALCS_corpus, # you can create a symlink @@ -116,7 +121,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi # Prepare text. - # Note: in Linux, you can install jq with the following command: + # Note: in Linux, you can install jq with the following command: # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 # 2. chmod +x ./jq # 3. cp jq /usr/bin diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index bf91fef7e..3bfb832fb 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -509,18 +509,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -530,9 +526,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py index 38f2ae83c..abba9d403 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -379,18 +379,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -400,9 +396,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index 01f08ce59..fb0e3116b 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -354,18 +354,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -375,9 +371,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/timit/ASR/README.md b/egs/timit/ASR/README.md index f10bfccfd..d493fc479 100644 --- a/egs/timit/ASR/README.md +++ b/egs/timit/ASR/README.md @@ -1,3 +1,3 @@ -Please refer to +Please refer to for how to run models in this recipe. diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index 04602ea2e..823b33ae5 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -516,18 +516,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -537,9 +533,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 7bd1177bd..32d5738b1 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -489,18 +489,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -510,9 +506,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index c7863415b..3a4dc3cb8 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -466,9 +466,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -477,9 +475,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -489,9 +485,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py index 6a67e26f8..b77f734e3 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py @@ -701,18 +701,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -722,9 +718,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py index ace792e13..e334e690a 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py @@ -593,18 +593,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -614,9 +610,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/yesno/ASR/README.md b/egs/yesno/ASR/README.md index 7257bad9a..38b491fc6 100644 --- a/egs/yesno/ASR/README.md +++ b/egs/yesno/ASR/README.md @@ -10,5 +10,5 @@ get the following WER: ``` Please refer to - + for detailed instructions. diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py index 9f680f83d..600f09f2b 100644 --- a/icefall/mmi_graph_compiler.py +++ b/icefall/mmi_graph_compiler.py @@ -74,7 +74,9 @@ class MmiTrainingGraphCompiler(object): # CAUTION: The following line is crucial. # Arcs entering the back-off state have label equal to #0. # We have to change it to 0 here. - P.labels[P.labels >= first_token_disambig_id] = 0 + labels = P.labels.clone() + labels[labels >= first_token_disambig_id] = 0 + P.labels = labels P = k2.remove_epsilon(P) P = k2.arc_sort(P) diff --git a/icefall/utils.py b/icefall/utils.py index 99e51a2a9..2358ed02f 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1,5 +1,6 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Mingshuang Luo, +# Zengwei Yao) # # See ../../LICENSE for clarification regarding multiple authors # @@ -453,11 +454,32 @@ def store_transcripts_and_timestamps( for cut_id, ref, hyp, time_ref, time_hyp in texts: print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) + if len(time_ref) > 0: - s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" + if isinstance(time_ref[0], tuple): + # each element is a pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_ref]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" print(f"{cut_id}:\ttimestamp_ref={s}", file=f) - s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" - print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) + + if len(time_hyp) > 0: + if isinstance(time_hyp[0], tuple): + # each element is a pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" + print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) def write_error_stats( @@ -624,9 +646,18 @@ def write_error_stats( def write_error_stats_with_timestamps( f: TextIO, test_set_name: str, - results: List[Tuple[str, List[str], List[str], List[float], List[float]]], + results: List[ + Tuple[ + str, + List[str], + List[str], + List[Union[float, Tuple[float, float]]], + List[Union[float, Tuple[float, float]]], + ] + ], enable_log: bool = True, -) -> Tuple[float, float, float]: + with_end_time: bool = False, +) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]: """Write statistics based on predicted results and reference transcripts as well as their timestamps. @@ -659,6 +690,8 @@ def write_error_stats_with_timestamps( enable_log: If True, also print detailed WER to the console. Otherwise, it is written only to the given file. + with_end_time: + Whether to use end timestamps. Returns: Return total word error rate, mean delay, and delay variance.
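The store_transcripts_and_timestamps change above makes the timestamp formatting polymorphic: each entry is either a plain float (a start time) or a (start, end) tuple. The same logic, condensed into a standalone helper for illustration (the helper name is hypothetical, not part of the diff):

from typing import List, Tuple, Union

def format_times(times: List[Union[float, Tuple[float, float]]]) -> str:
    # Renders [0.040, 0.120] for floats and [(0.040, 0.200), ...] for pairs,
    # with three decimals, matching the "%0.3f" formatting in the diff.
    if len(times) > 0 and isinstance(times[0], tuple):
        body = ", ".join("(%.3f, %.3f)" % (s, e) for (s, e) in times)
    else:
        body = ", ".join("%.3f" % t for t in times)
    return "[" + body + "]"

assert format_times([0.04, 0.12]) == "[0.040, 0.120]"
assert format_times([(0.04, 0.2)]) == "[(0.040, 0.200)]"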
@@ -704,7 +737,15 @@ def write_error_stats_with_timestamps( words[ref_word][0] += 1 num_corr += 1 if has_time: - all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) + if with_end_time: + all_delay.append( + ( + time_hyp[p_hyp][0] - time_ref[p_ref][0], + time_hyp[p_hyp][1] - time_ref[p_ref][1], + ) + ) + else: + all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) p_hyp += 1 p_ref += 1 if has_time: @@ -716,16 +757,39 @@ def write_error_stats_with_timestamps( ins_errs = sum(ins.values()) del_errs = sum(dels.values()) tot_errs = sub_errs + ins_errs + del_errs - tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len)) - mean_delay = "inf" - var_delay = "inf" + if with_end_time: + mean_delay = (float("inf"), float("inf")) + var_delay = (float("inf"), float("inf")) + else: + mean_delay = float("inf") + var_delay = float("inf") num_delay = len(all_delay) if num_delay > 0: - mean_delay = sum(all_delay) / num_delay - var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay - mean_delay = "%.3f" % mean_delay - var_delay = "%.3f" % var_delay + if with_end_time: + all_delay_start = [i[0] for i in all_delay] + mean_delay_start = sum(all_delay_start) / num_delay + var_delay_start = ( + sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay + ) + + all_delay_end = [i[1] for i in all_delay] + mean_delay_end = sum(all_delay_end) / num_delay + var_delay_end = ( + sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay + ) + + mean_delay = ( + float("%.3f" % mean_delay_start), + float("%.3f" % mean_delay_end), + ) + var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end)) + else: + mean_delay = sum(all_delay) / num_delay + var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay + mean_delay = float("%.3f" % mean_delay) + var_delay = float("%.3f" % var_delay) if enable_log: logging.info( @@ -734,7 +798,8 @@ def write_error_stats_with_timestamps( f"{del_errs} del, {sub_errs} sub ]" ) logging.info( - f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa + f"[{test_set_name}] %symbol-delay mean (s): " + f"{mean_delay}, variance: {var_delay} " # noqa f"computed on {num_delay} correct words" ) @@ -817,7 +882,8 @@ def write_error_stats_with_timestamps( hyp_count = corr + hyp_sub + ins print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) - return float(tot_err_rate), float(mean_delay), float(var_delay) + + return tot_err_rate, mean_delay, var_delay class MetricsTracker(collections.defaultdict): @@ -1395,3 +1461,306 @@ def is_module_available(*modules: str) -> bool: import importlib return all(importlib.util.find_spec(m) is not None for m in modules) + + +def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int): + """For an uneven-sized batch, the total duration after padding could cause OOM. + Hence, for each batch, which is sorted in descending order of length, + we simply drop the last few shortest samples, so that the retained total frames + (after padding) would not exceed the given allowed_max_frames. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + allowed_max_frames: + The allowed max number of frames in the batch.
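With with_end_time=True, each correctly recognized word contributes a (start-delay, end-delay) pair, and the mean and variance are computed per component, as in the hunk above. A self-contained sketch of that statistic (illustrative only, not a drop-in for write_error_stats_with_timestamps):

from typing import List, Tuple

def delay_stats(delays: List[Tuple[float, float]]):
    # Returns ((mean_start, mean_end), (var_start, var_end)) over hyp-minus-ref
    # delays; population variance, as in the diff.
    n = len(delays)
    if n == 0:
        inf = float("inf")
        return (inf, inf), (inf, inf)
    means, variances = [], []
    for comp in range(2):  # 0: start delays, 1: end delays
        xs = [d[comp] for d in delays]
        mean = sum(xs) / n
        var = sum((x - mean) ** 2 for x in xs) / n
        means.append(round(mean, 3))
        variances.append(round(var, 3))
    return tuple(means), tuple(variances)

# Hypotheses starting 40 ms late and ending 20 ms late, with no spread:
print(delay_stats([(0.04, 0.02), (0.04, 0.02)]))  # ((0.04, 0.02), (0.0, 0.0))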
+ """ + features = batch["inputs"] + supervisions = batch["supervisions"] + + N, T, _ = features.size() + assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max()) + keep_num_utt = allowed_max_frames // T + + if keep_num_utt >= N: + return batch + + # Note: we assume the samples in batch is sorted descendingly by length + logging.info( + f"Filtering uneven-sized batch, original batch size is {N}, " + f"retained batch size is {keep_num_utt}." + ) + batch["inputs"] = features[:keep_num_utt] + for k, v in supervisions.items(): + assert len(v) == N, (len(v), N) + batch["supervisions"][k] = v[:keep_num_utt] + + return batch + + +def parse_bpe_start_end_pairs( + tokens: List[str], is_first_token: List[bool] +) -> List[Tuple[int, int]]: + """Parse pairs of start and end frame indexes for each word. + + Args: + tokens: + List of BPE tokens. + is_first_token: + List of bool values, which indicates whether it is the first token, + i.e., not repeat or blank. + + Returns: + List of (start-frame-index, end-frame-index) pairs for each word. + """ + assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token)) + + start_token = b"\xe2\x96\x81".decode() # '_' + blank_token = "" + + non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token] + num_non_blank = len(non_blank_idx) + + pairs = [] + start = -1 + end = -1 + for j in range(num_non_blank): + # The index in all frames + i = non_blank_idx[j] + + found_start = False + if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)): + found_start = True + if tokens[i] == start_token: + if j == num_non_blank - 1: + # It is the last non-blank token + found_start = False + elif is_first_token[non_blank_idx[j + 1]] and tokens[ + non_blank_idx[j + 1] + ].startswith(start_token): + # The next not-blank token is a first-token and also starts with start_token + found_start = False + if found_start: + start = i + + if start != -1: + found_end = False + if j == num_non_blank - 1: + # It is the last non-blank token + found_end = True + elif is_first_token[non_blank_idx[j + 1]] and tokens[ + non_blank_idx[j + 1] + ].startswith(start_token): + # The next not-blank token is a first-token and also starts with start_token + found_end = True + if found_end: + end = i + + if start != -1 and end != -1: + if not all([tokens[t] == start_token for t in range(start, end + 1)]): + # except the case of all start_token + pairs.append((start, end)) + # Reset start and end + start = -1 + end = -1 + + return pairs + + +def parse_bpe_timestamps_and_texts( + best_paths: k2.Fsa, sp: spm.SentencePieceProcessor +) -> Tuple[List[Tuple[int, int]], List[List[str]]]: + """Parse timestamps (frame indexes) and texts. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). Its attribtutes `labels` and `aux_labels` + are both BPE tokens. + sp: + The BPE model. + + Returns: + utt_index_pairs: + A list of pair list. utt_index_pairs[i] is a list of + (start-frame-index, end-frame-index) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + shape = best_paths.arcs.shape().remove_axis(1) + + # labels: [utt][arcs] + labels = k2.RaggedTensor(shape, best_paths.labels.contiguous()) + # remove -1's. 
+ labels = labels.remove_values_eq(-1) + labels = labels.tolist() + + # aux_labels: [utt][arcs] + aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous()) + + # remove -1's. + all_aux_labels = aux_labels.remove_values_eq(-1) + # len(all_aux_labels[i]) is equal to the number of frames + all_aux_labels = all_aux_labels.tolist() + + # remove 0's and -1's. + out_aux_labels = aux_labels.remove_values_leq(0) + # len(out_aux_labels[i]) is equal to the number of output BPE tokens + out_aux_labels = out_aux_labels.tolist() + + utt_index_pairs = [] + utt_words = [] + for i in range(len(labels)): + tokens = sp.id_to_piece(labels[i]) + words = sp.decode(out_aux_labels[i]).split() + + # Indicates whether it is the first token, i.e., neither a repeat nor a blank. + is_first_token = [a != 0 for a in all_aux_labels[i]] + index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token) + assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens) + utt_index_pairs.append(index_pairs) + utt_words.append(words) + + return utt_index_pairs, utt_words + + +def parse_timestamps_and_texts( + best_paths: k2.Fsa, word_table: k2.SymbolTable +) -> Tuple[List[Tuple[int, int]], List[List[str]]]: + """Parse timestamps (frame indexes) and texts. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). Attribute `labels` is the prediction unit, + e.g., phone or BPE tokens. Attribute `aux_labels` is the word index. + word_table: + The word symbol table. + + Returns: + utt_index_pairs: + A list of pair lists. utt_index_pairs[i] is a list of + (start-frame-index, end-frame-index) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterance-i. + """ + # [utt][words] + word_ids = get_texts(best_paths) + + shape = best_paths.arcs.shape().remove_axis(1) + + # labels: [utt][arcs] + labels = k2.RaggedTensor(shape, best_paths.labels.contiguous()) + # remove -1's. + labels = labels.remove_values_eq(-1) + labels = labels.tolist() + + # aux_labels: [utt][arcs] + aux_shape = shape.compose(best_paths.aux_labels.shape) + aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous()) + aux_labels = aux_labels.tolist() + + utt_index_pairs = [] + utt_words = [] + for i, (label, aux_label) in enumerate(zip(labels, aux_labels)): + num_arcs = len(label) + # The last arc of aux_label is the arc entering the final state + assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label)) + + index_pairs = [] + start = -1 + end = -1 + for arc in range(num_arcs): + # len(aux_label[arc]) is 0 or 1 + if label[arc] != 0 and len(aux_label[arc]) != 0: + if start != -1 and end != -1: + index_pairs.append((start, end)) + start = arc + if label[arc] != 0: + end = arc + if start != -1 and end != -1: + index_pairs.append((start, end)) + + words = [word_table[w] for w in word_ids[i]] + assert len(index_pairs) == len(words), (len(index_pairs), len(words)) + + utt_index_pairs.append(index_pairs) + utt_words.append(words) + + return utt_index_pairs, utt_words + + +def parse_fsa_timestamps_and_texts( + best_paths: k2.Fsa, + sp: Optional[spm.SentencePieceProcessor] = None, + word_table: Optional[k2.SymbolTable] = None, + subsampling_factor: int = 4, + frame_shift_ms: float = 10, +) -> Tuple[List[Tuple[float, float]], List[List[str]]]: + """Parse timestamps (in seconds) and texts for given decoded fsa paths.
+ Currently it supports two cases: + (1) ctc-decoding, the attributes `labels` and `aux_labels` + are both BPE tokens. In this case, sp should be provided. + (2) HLG-based 1best, the attribute `labels` is the prediction unit, + e.g., phone or BPE tokens; attribute `aux_labels` is the word index. + In this case, word_table should be provided. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + sp: + The BPE model. + word_table: + The word symbol table. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + + Returns: + utt_time_pairs: + A list of pair lists. utt_time_pairs[i] is a list of + (start-time, end-time) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterance-i. + """ + if sp is not None: + assert word_table is None, "word_table is not needed if sp is provided." + utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts( + best_paths=best_paths, sp=sp + ) + elif word_table is not None: + assert sp is None, "sp is not needed if word_table is provided." + utt_index_pairs, utt_words = parse_timestamps_and_texts( + best_paths=best_paths, word_table=word_table + ) + else: + raise ValueError("Either sp or word_table should be provided.") + + utt_time_pairs = [] + for utt in utt_index_pairs: + start = convert_timestamp( + frames=[i[0] for i in utt], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + end = convert_timestamp( + # The duration in frames is (end_frame_index - start_frame_index + 1) + frames=[i[1] + 1 for i in utt], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + utt_time_pairs.append(list(zip(start, end))) + + return utt_time_pairs, utt_words diff --git a/requirements-ci.txt b/requirements-ci.txt index b8e49899e..50d4e5e3f 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -22,5 +22,6 @@ typeguard==2.13.3 multi_quantization onnx +onnxmltools onnxruntime kaldifst diff --git a/test/test_parse_timestamp.py b/test/test_parse_timestamp.py new file mode 100755 index 000000000..92bfb49c6 --- /dev/null +++ b/test/test_parse_timestamp.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
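parse_fsa_timestamps_and_texts converts the frame-index pairs to seconds via convert_timestamp, which is defined elsewhere in icefall/utils.py and not shown in this patch; presumably a frame index i maps to i * subsampling_factor * frame_shift_ms / 1000 seconds, and the `+ 1` on end indexes makes the end time cover the full final frame. A hedged sketch of that conversion:

def frames_to_seconds(frames, subsampling_factor=4, frame_shift_ms=10.0):
    # What convert_timestamp presumably computes for encoder-output frames.
    return [f * subsampling_factor * frame_shift_ms / 1000.0 for f in frames]

# A word spanning output frames (0, 3) with 4x subsampling and a 10 ms shift:
start = frames_to_seconds([0])[0]    # 0.0 s
end = frames_to_seconds([3 + 1])[0]  # 0.16 s; the +1 makes the end exclusive
print((start, end))                  # (0.0, 0.16)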
+ + +from pathlib import Path + +import k2 +import sentencepiece as spm +import torch + +from icefall.lexicon import Lexicon +from icefall.utils import parse_bpe_timestamps_and_texts, parse_timestamps_and_texts + +ICEFALL_DIR = Path(__file__).resolve().parent.parent + + +def test_parse_bpe_timestamps_and_texts(): + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500" + if not lang_dir.is_dir(): + print(f"{lang_dir} does not exist.") + return + + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + + text_1 = "HELLO WORLD" + token_ids_1 = sp.encode(text_1, out_type=int) + # out_type=str: ['_HE', 'LL', 'O', '_WORLD'] + # out_type=int: [22, 58, 24, 425] + + # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0] + labels_1 = ( + token_ids_1[0:1] * 2 + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] * 3 + + [0] * 2 + ) + # [22, 0, 58, 24, 0, 0, 425, 0, 0, 0, 0, -1] + aux_labels_1 = ( + token_ids_1[0:1] + + [0] + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] + + [0] * 4 + + [-1] + ) + fsa_1 = k2.linear_fsa(labels_1) + fsa_1.aux_labels = torch.tensor(aux_labels_1).to(torch.int32) + + text_2 = "SAY GOODBYE" + token_ids_2 = sp.encode(text_2, out_type=int) + # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E'] + # out_type=int: [289, 286, 41, 16, 11] + + # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0] + labels_2 = ( + token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2 + ) + # [289, 0, 0, 286, 0, 41, 16, 11, 0, 0, -1] + aux_labels_2 = ( + token_ids_2[0:1] + + [0] * 2 + + token_ids_2[1:2] + + [0] + + token_ids_2[2:5] + + [0] * 2 + + [-1] + ) + fsa_2 = k2.linear_fsa(labels_2) + fsa_2.aux_labels = torch.tensor(aux_labels_2).to(torch.int32) + + fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2]) + + utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(fsa_vec, sp) + assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0] + assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0] + assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1] + assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1] + + +def test_parse_timestamps_and_texts(): + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500" + if not lang_dir.is_dir(): + print(f"{lang_dir} does not exist.") + return + + lexicon = Lexicon(lang_dir) + + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + word_table = lexicon.word_table + + text_1 = "HELLO WORLD" + token_ids_1 = sp.encode(text_1, out_type=int) + # out_type=str: ['_HE', 'LL', 'O', '_WORLD'] + # out_type=int: [22, 58, 24, 425] + word_ids_1 = [word_table[s] for s in text_1.split()] # [79677, 196937] + # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0] + labels_1 = ( + token_ids_1[0:1] * 2 + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] * 3 + + [0] * 2 + ) + # [[79677], [], [], [], [], [], [196937], [], [], [], [], []] + aux_labels_1 = [word_ids_1[0:1]] + [[]] * 5 + [word_ids_1[1:2]] + [[]] * 5 + + fsa_1 = k2.linear_fsa(labels_1) + fsa_1.aux_labels = k2.RaggedTensor(aux_labels_1) + + text_2 = "SAY GOODBYE" + token_ids_2 = sp.encode(text_2, out_type=int) + # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E'] + # out_type=int: [289, 286, 41, 16, 11] + word_ids_2 = [word_table[s] for s in text_2.split()] # [154967, 72079] + # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0] + labels_2 = ( + token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2 + ) + # [[154967], [], [], [72079], [], [], [], [], [], [], []] + aux_labels_2 = [word_ids_2[0:1]] + [[]] * 2 + [word_ids_2[1:2]] + [[]] * 7 
+ + fsa_2 = k2.linear_fsa(labels_2) + fsa_2.aux_labels = k2.RaggedTensor(aux_labels_2) + + fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2]) + + utt_index_pairs, utt_words = parse_timestamps_and_texts(fsa_vec, word_table) + assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0] + assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0] + assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1] + assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1] + + +if __name__ == "__main__": + test_parse_bpe_timestamps_and_texts() + test_parse_timestamps_and_texts()
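The new test file is self-guarding: each test returns early when egs/librispeech/ASR/data/lang_bpe_500 is missing, so it can run in checkouts without the BPE model. With the repository root on PYTHONPATH, `python3 test/test_parse_timestamp.py` runs both tests directly (the file is executable, mode 100755), and pytest will also collect them via the `test_` prefix.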