diff --git a/.flake8 b/.flake8 index 22cd63b3d..609fa2c03 100644 --- a/.flake8 +++ b/.flake8 @@ -9,7 +9,7 @@ per-file-ignores = egs/*/ASR/pruned_transducer_stateless*/*.py: E501, egs/*/ASR/*/optim.py: E501, egs/*/ASR/*/scaling.py: E501, - egs/librispeech/ASR/lstm_transducer_stateless/*.py: E501, E203 + egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203 egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203 egs/librispeech/ASR/conformer_ctc2/*py: E501, egs/librispeech/ASR/RESULTS.md: E999, diff --git a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh index a4a6cd8d7..bb7c7dfdc 100755 --- a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh +++ b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh @@ -4,6 +4,8 @@ # The computed features are saved to ~/tmp/fbank-libri and are # cached for later runs +set -e + export PYTHONPATH=$PWD:$PYTHONPATH echo $PYTHONPATH diff --git a/.github/scripts/download-gigaspeech-dev-test-dataset.sh b/.github/scripts/download-gigaspeech-dev-test-dataset.sh index b9464de9f..f3564efc7 100755 --- a/.github/scripts/download-gigaspeech-dev-test-dataset.sh +++ b/.github/scripts/download-gigaspeech-dev-test-dataset.sh @@ -6,6 +6,8 @@ # You will find directories `~/tmp/giga-dev-dataset-fbank` after running # this script. +set -e + mkdir -p ~/tmp cd ~/tmp diff --git a/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh b/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh index 3efcc13e3..11704526c 100755 --- a/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh +++ b/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh @@ -7,6 +7,8 @@ # You will find directories ~/tmp/download/LibriSpeech after running # this script. +set -e + mkdir ~/tmp/download cd egs/librispeech/ASR ln -s ~/tmp/download . diff --git a/.github/scripts/install-kaldifeat.sh b/.github/scripts/install-kaldifeat.sh index 6666a5064..de30f7dfe 100755 --- a/.github/scripts/install-kaldifeat.sh +++ b/.github/scripts/install-kaldifeat.sh @@ -3,6 +3,8 @@ # This script installs kaldifeat into the directory ~/tmp/kaldifeat # which is cached by GitHub actions for later runs. +set -e + mkdir -p ~/tmp cd ~/tmp git clone https://github.com/csukuangfj/kaldifeat diff --git a/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh b/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh index e0b87e0fc..1b48aae27 100755 --- a/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh +++ b/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh @@ -4,6 +4,8 @@ # to egs/librispeech/ASR/download/LibriSpeech and generates manifest # files in egs/librispeech/ASR/data/manifests +set -e + cd egs/librispeech/ASR [ ! -e download ] && ln -s ~/tmp/download . 
mkdir -p data/manifests diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh index 631707ad9..e70a1848d 100755 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} @@ -40,7 +42,7 @@ for sym in 1 2 3; do --lang-dir $repo/data/lang_char \ $repo/test_wavs/BAC009S0764W0121.wav \ $repo/test_wavs/BAC009S0764W0122.wav \ - $rep/test_wavs/BAC009S0764W0123.wav + $repo/test_wavs/BAC009S0764W0123.wav done for method in modified_beam_search beam_search fast_beam_search; do @@ -53,7 +55,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --lang-dir $repo/data/lang_char \ $repo/test_wavs/BAC009S0764W0121.wav \ $repo/test_wavs/BAC009S0764W0122.wav \ - $rep/test_wavs/BAC009S0764W0123.wav + $repo/test_wavs/BAC009S0764W0123.wav done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" diff --git a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh index 528d04cd1..c8d9c6b77 100755 --- a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh +++ b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml new file mode 100755 index 000000000..b89055c72 --- /dev/null +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -0,0 +1,203 @@ +#!/usr/bin/env bash +# +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s pretrained-iter-468000-avg-16.pt pretrained.pt +ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt +popd + +log "Install ncnn and pnnx" + +# We are using a modified ncnn here. Will try to merge it to the official repo +# of ncnn +git clone https://github.com/csukuangfj/ncnn +pushd ncnn +git submodule init +git submodule update python/pybind11 +python3 setup.py bdist_wheel +ls -lh dist/ +pip install dist/*.whl +cd tools/pnnx +mkdir build +cd build +cmake .. 
+make -j4 pnnx + +./src/pnnx || echo "pass" + +popd + +log "Test exporting to pnnx format" + +./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --pnnx 1 + +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +./lstm_transducer_stateless2/ncnn-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + +./lstm_transducer_stateless2/streaming-ncnn-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav + + + +log "Test exporting with torch.jit.trace()" + +./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --jit-trace 1 + +log "Decode with models exported by torch.jit.trace()" + +./lstm_transducer_stateless2/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Test exporting to ONNX" + +./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --onnx 1 + +log "Decode with ONNX models " + +./lstm_transducer_stateless2/streaming-onnx-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo//exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1089-134686-0001.wav + +./lstm_transducer_stateless2/streaming-onnx-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo//exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1221-135766-0001.wav + 
+./lstm_transducer_stateless2/streaming-onnx-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo//exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1221-135766-0002.wav + + + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./lstm_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./lstm_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then + mkdir -p lstm_transducer_stateless2/exp + ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./lstm_transducer_stateless2/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir lstm_transducer_stateless2/exp + done + + rm lstm_transducer_stateless2/exp/*.pt +fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index bd816c2d6..dafea56db 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index 6b5b51bd7..ae2bb6822 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} @@ -11,10 +13,14 @@ cd egs/librispeech/ASR repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29 log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt" +popd + log "Display test files" tree 
$repo/ soxi $repo/test_wavs/*.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index 62ea02c47..172d7ad4c 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index bdc8a3838..880767443 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} @@ -58,17 +60,17 @@ log "Decode with ONNX models" --jit-filename $repo/exp/cpu_jit.pt \ --onnx-encoder-filename $repo/exp/encoder.onnx \ --onnx-decoder-filename $repo/exp/decoder.onnx \ - --onnx-joiner-filename $repo/exp/joiner.onnx - -./pruned_transducer_stateless3/onnx_check_all_in_one.py \ - --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-all-in-one-filename $repo/exp/all_in_one.onnx + --onnx-joiner-filename $repo/exp/joiner.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx ./pruned_transducer_stateless3/onnx_pretrained.py \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --encoder-model-filename $repo/exp/encoder.onnx \ --decoder-model-filename $repo/exp/decoder.onnx \ --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index c893bc45a..c6a781318 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index d9dc34e48..af37102d5 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index c22660d0a..5b8ed396b 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff 
--git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 96a072c46..96c320616 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} @@ -10,7 +12,6 @@ cd egs/librispeech/ASR repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 git lfs install -git clone $repo log "Downloading pre-trained model from $repo_url" git clone $repo_url diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index dcc99d62e..209d4814f 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 9622224c9..34ff76fe4 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh index 168aee766..75650c2d3 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh index 9211b22eb..bcc2d74cb 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index 4a1dc1a7e..d3e40315a 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index 5f8a5b3a5..cfa006776 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh new file mode 100755 index 000000000..2d237dcf2 --- /dev/null +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -0,0 +1,124 @@ 
+#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/wenetspeech/ASR + +repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s pretrained_epoch_10_avg_2.pt pretrained.pt +ln -s pretrained_epoch_10_avg_2.pt epoch-99.pt +popd + +log "Test exporting to ONNX format" + +./pruned_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --lang-dir $repo/data/lang_char \ + --epoch 99 \ + --avg 1 \ + --onnx 1 + +log "Export to torchscript model" + +./pruned_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --lang-dir $repo/data/lang_char \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +./pruned_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --lang-dir $repo/data/lang_char \ + --epoch 99 \ + --avg 1 \ + --jit-trace 1 + +ls -lh $repo/exp/*.onnx +ls -lh $repo/exp/*.pt + +log "Decode with ONNX models" + +./pruned_transducer_stateless2/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder.onnx \ + --onnx-decoder-filename $repo/exp/decoder.onnx \ + --onnx-joiner-filename $repo/exp/joiner.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx + +./pruned_transducer_stateless2/onnx_pretrained.py \ + --tokens $repo/data/lang_char/tokens.txt \ + --encoder-model-filename $repo/exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav + +log "Decode with models exported by torch.jit.trace()" + +./pruned_transducer_stateless2/jit_pretrained.py \ + --tokens $repo/data/lang_char/tokens.txt \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav + +./pruned_transducer_stateless2/jit_pretrained.py \ + --tokens $repo/data/lang_char/tokens.txt \ + --encoder-model-filename $repo/exp/encoder_jit_script.pt \ + --decoder-model-filename $repo/exp/decoder_jit_script.pt \ + --joiner-model-filename $repo/exp/joiner_jit_script.pt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless2/pretrained.py \ + --checkpoint $repo/exp/epoch-99.pt \ + --lang-dir $repo/data/lang_char \ + --decoding-method greedy_search \ + --max-sym-per-frame $sym \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + 
./pruned_transducer_stateless2/pretrained.py \ + --decoding-method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/epoch-99.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav +done diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml index e684e598e..e46b01a08 100644 --- a/.github/workflows/run-aishell-2022-06-20.yml +++ b/.github/workflows/run-aishell-2022-06-20.yml @@ -69,7 +69,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index dc33751d3..c631927fa 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -68,7 +68,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml index 291f2bc71..5df710006 100644 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ b/.github/workflows/run-librispeech-2022-03-12.yml @@ -68,7 +68,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml index b04718f86..24c062442 100644 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ b/.github/workflows/run-librispeech-2022-04-29.yml @@ -68,7 +68,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml index bb3d74e55..29215ec25 100644 --- a/.github/workflows/run-librispeech-2022-05-13.yml +++ b/.github/workflows/run-librispeech-2022-05-13.yml @@ -68,7 +68,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml new file mode 100644 index 000000000..dd67771ba --- /dev/null +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -0,0 +1,136 @@ +name: run-librispeech-lstm-transducer2-2022-09-03 + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_lstm_transducer_stateless2_2022_09_03: + if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name 
== 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml + + - name: Display decoding results for lstm_transducer_stateless2 + if: github.event_name == 'schedule' + shell: bash + run: | + cd egs/librispeech/ASR + tree lstm_transducer_stateless2/exp + cd lstm_transducer_stateless2/exp + echo "===greedy search===" + find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - 
name: Upload decoding results for lstm_transducer_stateless2 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03 + path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/ diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml index 3b6e11a31..66a2c240b 100644 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ -68,7 +68,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml index 9ce8244da..55428861c 100644 --- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml +++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml @@ -68,7 +68,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml index e05b04bee..f520405e1 100644 --- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml @@ -68,7 +68,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index f4c6bf507..9bc6a481f 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -58,7 +58,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml index 348a68095..7a0f30b0f 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml @@ -67,7 +67,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml index d1369c2b1..797f3fe50 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ 
b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml @@ -67,7 +67,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 9d095a0aa..29e665881 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -58,7 +58,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml index 868fe6fbe..6193f28e7 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -58,7 +58,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 78c1ca059..32208076c 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -67,7 +67,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml index 959e57278..965d0f655 100644 --- a/.github/workflows/run-pretrained-transducer.yml +++ b/.github/workflows/run-pretrained-transducer.yml @@ -58,7 +58,7 @@ jobs: with: path: | ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }} + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 - name: Install kaldifeat if: steps.my-cache.outputs.cache-hit != 'true' diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml new file mode 100644 index 000000000..d96a3bfe6 --- /dev/null +++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml @@ -0,0 +1,80 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-wenetspeech-pruned-transducer-stateless2 + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +jobs: + run_wenetspeech_pruned_transducer_stateless2: + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 239a0280c..90459bc1c 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -29,8 +29,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-18.04, macos-latest] - python-version: [3.7, 3.9] + os: [ubuntu-latest] + python-version: [3.8] fail-fast: false steps: diff --git a/.gitignore b/.gitignore index 1dbf8f395..406deff6a 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,5 @@ log *.bak *-bak *bak.py +*.param +*.bin diff --git a/docker/README.md b/docker/README.md index 0c8cb0ed9..0a39b7a49 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,24 +1,114 @@ # icefall dockerfile -We provide a dockerfile for some users, the configuration of dockerfile is : Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8-python3.8. You can use the dockerfile by following the steps: +Two sets of configurations are provided: (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. -## Building images locally +If your NVIDIA driver supports CUDA 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8. + +Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVIDIA driver supports at least CUDA 11.0. + +You can check the highest CUDA version that your NVIDIA driver supports with the `nvidia-smi` command below. In this example, the highest CUDA version is 11.0, i.e. case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
```bash -cd docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8 -docker build -t icefall/pytorch1.7.1:latest -f ./Dockerfile ./ +$ nvidia-smi +Tue Sep 20 00:26:13 2022 ++-----------------------------------------------------------------------------+ +| NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 | +|-------------------------------+----------------------+----------------------+ +| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|===============================+======================+======================| +| 0 TITAN RTX On | 00000000:03:00.0 Off | N/A | +| 41% 31C P8 4W / 280W | 16MiB / 24219MiB | 0% Default | +| | | N/A | ++-------------------------------+----------------------+----------------------+ +| 1 TITAN RTX On | 00000000:04:00.0 Off | N/A | +| 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default | +| | | N/A | ++-------------------------------+----------------------+----------------------+ + ++-----------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=============================================================================| +| 0 N/A N/A 2085 G /usr/lib/xorg/Xorg 9MiB | +| 0 N/A N/A 2240 G /usr/bin/gnome-shell 4MiB | +| 1 N/A N/A 2085 G /usr/lib/xorg/Xorg 4MiB | ++-----------------------------------------------------------------------------+ + ``` -## Using built images -Sample usage of the GPU based images: +## Building images locally +If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly. +For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details. + +```dockerfile +ENV http_proxy=http://aaa.bb.cc.net:8080 \ + https_proxy=http://aaa.bb.cc.net:8080 +``` + +Then, proceed with these commands. + +### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3: + +```bash +cd docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8 +docker build -t icefall/pytorch1.12.1 . +``` + +### If you are case (b), i.e. your NVIDIA driver can only support CUDA versions 11.0 <= x < 11.3: +```bash +cd docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8 +docker build -t icefall/pytorch1.7.1 . +``` + +## Running your built local image +Sample usage of the GPU based images. These commands are written with case (a) in mind, so please make the necessary changes to your image name if you are case (b). Note: use [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) to run the GPU images. ```bash -docker run -it --runtime=nvidia --name=icefall_username --gpus all icefall/pytorch1.7.1:latest +docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall/pytorch1.12.1 ``` -Sample usage of the CPU based images: +### Tips: +1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/docker}:{/path/in/host/machine}`. + +2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`. + +Overall, your docker run command should look like this. 
```bash -docker run -it icefall/pytorch1.7.1:latest /bin/bash -``` \ No newline at end of file +docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1 +``` + +You can explore more docker run options [here](https://docs.docker.com/engine/reference/commandline/run/) to suit your environment. + +### Linking to icefall in your host machine + +If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container. + +Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine. +Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below. + +Use these commands once you are inside the container. + +```bash +rm -r /workspace/icefall +ln -s {/path/in/docker/to/icefall} /workspace/icefall +``` + +## Starting another session in the same running container. +```bash +docker exec -it icefall /bin/bash +``` + +## Restarting a stopped container that has been run before. +```bash +docker start -ai icefall +``` + +## Sample usage of the CPU-based images: +```bash +docker run -it icefall /bin/bash +``` diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile new file mode 100644 index 000000000..db4dda864 --- /dev/null +++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile @@ -0,0 +1,72 @@ +FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel + +# ENV http_proxy=http://aaa.bbb.cc.net:8080 \ +# https_proxy=http://aaa.bbb.cc.net:8080 + +# install normal source +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + g++ \ + make \ + automake \ + autoconf \ + bzip2 \ + unzip \ + wget \ + sox \ + libtool \ + git \ + subversion \ + zlib1g-dev \ + gfortran \ + ca-certificates \ + patch \ + ffmpeg \ + valgrind \ + libssl-dev \ + vim \ + curl + +# cmake +RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ + cd /opt && \ + tar -zxvf cmake-3.18.0.tar.gz && \ + cd cmake-3.18.0 && \ + ./bootstrap && \ + make && \ + make install && \ + rm -rf cmake-3.18.0.tar.gz && \ + find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ + cd - + +# flac +RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ + cd /opt && \ + xz -d flac-1.3.2.tar.xz && \ + tar -xvf flac-1.3.2.tar && \ + cd flac-1.3.2 && \ + ./configure && \ + make && make install && \ + rm -rf flac-1.3.2.tar && \ + find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ + cd - + +RUN pip install kaldiio graphviz && \ + conda install -y -c pytorch torchaudio + +#install k2 from source +RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ + cd /opt/k2 && \ + python3 setup.py install && \ + cd - + +# install lhotse +RUN pip install git+https://github.com/lhotse-speech/lhotse + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall \ No newline at end of file diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile index 746c2c4f3..7a14a00ad 100644 --- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile +++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile @@ -1,7 +1,13 @@ FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel -# install normal source +# ENV http_proxy=http://aaa.bbb.cc.net:8080 \ +# https_proxy=http://aaa.bbb.cc.net:8080 +RUN rm /etc/apt/sources.list.d/cuda.list && \ + rm /etc/apt/sources.list.d/nvidia-ml.list && \ + apt-key del 7fa2af80 + +# install normal source RUN apt-get update && \ apt-get install -y --no-install-recommends \ g++ \ @@ -21,20 +27,25 @@ RUN apt-get update && \ patch \ ffmpeg \ valgrind \ - libssl-dev \ - vim && \ - rm -rf /var/lib/apt/lists/* + libssl-dev \ + vim \ + curl - -RUN mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \ +# Add new keys and reupdate +RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub | apt-key add - && \ + curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \ + echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \ + echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \ + rm -rf /var/lib/apt/lists/* && \ + mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \ mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \ mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \ mv /opt/conda/lib/libnvrtc.so.11.0 /opt/libnvrtc.so.11.1.bak && \ - mv /opt/conda/lib/libnvToolsExt.so.1 /opt/libnvToolsExt.so.1.bak && \ - mv /opt/conda/lib/libcudart.so.11.0 /opt/libcudart.so.11.0.bak + # mv /opt/conda/lib/libnvToolsExt.so.1 /opt/libnvToolsExt.so.1.bak && \ + mv /opt/conda/lib/libcudart.so.11.0 /opt/libcudart.so.11.0.bak && \ + apt-get update && apt-get -y upgrade # cmake - RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ cd /opt && \ tar -zxvf cmake-3.18.0.tar.gz && \ @@ -45,11 +56,7 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \ rm -rf cmake-3.18.0.tar.gz && \ find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - - -#kaldiio - -RUN pip install kaldiio - + # flac RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \ cd /opt && \ @@ -62,15 +69,8 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \ cd - -# graphviz -RUN pip install graphviz - -# kaldifeat -RUN git clone https://github.com/csukuangfj/kaldifeat.git /opt/kaldifeat && \ - cd /opt/kaldifeat && \ - python setup.py install && \ - cd - - +RUN pip install kaldiio graphviz && \ + conda install -y -c pytorch torchaudio=0.7.1 #install k2 from source RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ @@ -79,14 +79,13 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \ cd - # install lhotse -RUN pip install torchaudio==0.7.2 -RUN pip install git+https://github.com/lhotse-speech/lhotse -#RUN pip install lhotse +RUN pip install git+https://github.com/lhotse-speech/lhotse -# install icefall -RUN git clone https://github.com/k2-fsa/icefall && \ - cd icefall && \ - pip install -r requirements.txt -i 
https://pypi.tuna.tsinghua.edu.cn/simple - -ENV PYTHONPATH /workspace/icefall:$PYTHONPATH +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docs/source/conf.py b/docs/source/conf.py index afac002d4..221d9d734 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -74,7 +74,7 @@ html_context = { "github_user": "k2-fsa", "github_repo": "icefall", "github_version": "master", - "conf_py_path": "/icefall/docs/source/", + "conf_py_path": "/docs/source/", } todo_include_todos = True diff --git a/docs/source/index.rst b/docs/source/index.rst index 29491e3dc..be9977ca9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,6 +21,7 @@ speech recognition recipes using `k2 `_. :caption: Contents: installation/index + model-export/index recipes/index contributing/index huggingface/index diff --git a/docs/source/model-export/code/export-model-state-dict-pretrained-out.txt b/docs/source/model-export/code/export-model-state-dict-pretrained-out.txt new file mode 100644 index 000000000..8d2d6d34b --- /dev/null +++ b/docs/source/model-export/code/export-model-state-dict-pretrained-out.txt @@ -0,0 +1,21 @@ +2022-10-13 19:09:02,233 INFO [pretrained.py:265] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'encoder_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.21', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4810e00d8738f1a21278b0156a42ff396a2d40ac', 'k2-git-date': 'Fri Oct 7 19:35:03 2022', 'lhotse-version': '1.3.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'onnx-doc-1013', 'icefall-git-sha1': 'c39cba5-dirty', 'icefall-git-date': 'Thu Oct 13 15:17:20 2022', 'icefall-path': '/k2-dev/fangjun/open-source/icefall-master', 'k2-path': '/k2-dev/fangjun/open-source/k2-master/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-jsonl/lhotse/__init__.py', 'hostname': 'de-74279-k2-test-4-0324160024-65bfd8b584-jjlbn', 'IP address': '10.177.74.203'}, 'checkpoint': './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt', 'bpe_model': './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model', 'method': 'greedy_search', 'sound_files': ['./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav'], 'sample_rate': 16000, 'beam_size': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 8, 'context_size': 2, 'max_sym_per_frame': 1, 'simulate_streaming': False, 'decode_chunk_size': 16, 'left_context': 64, 'dynamic_chunk_training': False, 'causal_convolution': False, 'short_chunk_size': 25, 'num_left_chunks': 4, 'blank_id': 0, 'unk_id': 2, 'vocab_size': 500} +2022-10-13 19:09:02,233 INFO [pretrained.py:271] device: cpu +2022-10-13 19:09:02,233 INFO 
[pretrained.py:273] Creating model +2022-10-13 19:09:02,612 INFO [train.py:458] Disable giga +2022-10-13 19:09:02,623 INFO [pretrained.py:277] Number of model parameters: 78648040 +2022-10-13 19:09:02,951 INFO [pretrained.py:285] Constructing Fbank computer +2022-10-13 19:09:02,952 INFO [pretrained.py:295] Reading sound files: ['./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav', './icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav'] +2022-10-13 19:09:02,957 INFO [pretrained.py:301] Decoding started +2022-10-13 19:09:06,700 INFO [pretrained.py:329] Using greedy_search +2022-10-13 19:09:06,912 INFO [pretrained.py:388] +./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav: +AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + +./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav: +GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + +./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav: +YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + +2022-10-13 19:09:06,912 INFO [pretrained.py:390] Decoding Done diff --git a/docs/source/model-export/export-model-state-dict.rst b/docs/source/model-export/export-model-state-dict.rst new file mode 100644 index 000000000..c3bbd5708 --- /dev/null +++ b/docs/source/model-export/export-model-state-dict.rst @@ -0,0 +1,135 @@ +Export model.state_dict() +========================= + +When to use it +-------------- + +During model training, we save checkpoints periodically to disk. + +A checkpoint contains the following information: + + - ``model.state_dict()`` + - ``optimizer.state_dict()`` + - and some other information related to training + +When we need to resume the training process from some point, we need a checkpoint. +However, if we want to publish the model for inference, then only +``model.state_dict()`` is needed. In this case, we need to strip all other information +except ``model.state_dict()`` to reduce the file size of the published model. + +How to export +------------- + +Every recipe contains a file ``export.py`` that you can use to +export ``model.state_dict()`` by taking some checkpoints as inputs. + +.. hint:: + + Each ``export.py`` contains well-documented usage information. + +In the following, we use +``_ +as an example. + +.. note:: + + The steps for other recipes are almost the same. + +.. code-block:: bash + + cd egs/librispeech/ASR + + ./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +will generate a file ``pruned_transducer_stateless3/exp/pretrained.pt``, which +is a dict containing ``{"model": model.state_dict()}`` saved by ``torch.save()``. + +How to use the exported model +----------------------------- + +For each recipe, we provide pretrained models hosted on huggingface. +You can find links to pretrained models in ``RESULTS.md`` of each dataset. + +In the following, we demonstrate how to use the pretrained model from +``_. 
+ +.. code-block:: bash + + cd egs/librispeech/ASR + + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 + +After cloning the repo with ``git lfs``, you will find several files in the folder +``icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp`` +that have a prefix ``pretrained-``. Those files contain ``model.state_dict()`` +exported by the above ``export.py``. + +In each recipe, there is also a file ``pretrained.py``, which can use +``pretrained-xxx.pt`` to decode waves. The following is an example: + +.. code-block:: bash + + cd egs/librispeech/ASR + + ./pruned_transducer_stateless3/pretrained.py \ + --checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model \ + --method greedy_search \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \ + ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav + +The above commands show how to use the exported model with ``pretrained.py`` to +decode multiple sound files. Its output is given as follows for reference: + +.. literalinclude:: ./code/export-model-state-dict-pretrained-out.txt + +Use the exported model to run decode.py +--------------------------------------- + +When we publish the model, we always note down its WERs on some test +dataset in ``RESULTS.md``. This section describes how to use the +pretrained model to reproduce the WER. + +.. code-block:: bash + + cd egs/librispeech/ASR + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 + + cd icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp + ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt + cd ../.. + +We create a symlink with name ``epoch-9999.pt`` to ``pretrained-iter-1224000-avg-14.pt``, +so that we can pass ``--epoch 9999 --avg 1`` to ``decode.py`` in the following +command: + +.. code-block:: bash + + ./pruned_transducer_stateless3/decode.py \ + --epoch 9999 \ + --avg 1 \ + --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp \ + --lang-dir ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500 \ + --max-duration 600 \ + --decoding-method greedy_search + +You will find the decoding results in +``./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/greedy_search``. + +.. caution:: + + For some recipes, you also need to pass ``--use-averaged-model False`` + to ``decode.py``. The reason is that the exported pretrained model is already + the averaged one. + +.. hint:: + + Before running ``decode.py``, we assume that you have already run + ``prepare.sh`` to prepare the test dataset. diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst new file mode 100644 index 000000000..3dbb8b514 --- /dev/null +++ b/docs/source/model-export/export-ncnn.rst @@ -0,0 +1,12 @@ +Export to ncnn +============== + +We support exporting LSTM transducer models to `ncnn `_. + +Please refer to :ref:`export-model-for-ncnn` for details. + +We also provide ``_ +performing speech recognition using ``ncnn`` with exported models. 
+It has been tested on Linux, macOS, Windows, and Raspberry Pi. The project is +self-contained and can be statically linked to produce a binary containing +everything needed. diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst new file mode 100644 index 000000000..dd4b3437a --- /dev/null +++ b/docs/source/model-export/export-onnx.rst @@ -0,0 +1,69 @@ +Export to ONNX +============== + +In this section, we describe how to export models to ONNX. + +.. hint:: + + Only non-streaming conformer transducer models are tested. + + +When to use it +-------------- + +If you want to use an inference framework that supports ONNX +to run the pretrained model. + + +How to export +------------- + +We use +``_ +as an example in the following. + +.. code-block:: bash + + cd egs/librispeech/ASR + epoch=14 + avg=2 + + ./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch $epoch \ + --avg $avg \ + --onnx 1 + +It will generate the following files inside ``pruned_transducer_stateless3/exp``: + + - ``encoder.onnx`` + - ``decoder.onnx`` + - ``joiner.onnx`` + - ``joiner_encoder_proj.onnx`` + - ``joiner_decoder_proj.onnx`` + +You can use ``./pruned_transducer_stateless3/onnx_pretrained.py`` to decode +waves with the generated files: + +.. code-block:: bash + + ./pruned_transducer_stateless3/onnx_pretrained.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ + /path/to/foo.wav \ + /path/to/bar.wav \ + /path/to/baz.wav + + +How to use the exported model +----------------------------- + +We also provide ``_ +performing speech recognition using `onnxruntime `_ +with exported models. +It has been tested on Linux, macOS, and Windows. diff --git a/docs/source/model-export/export-with-torch-jit-script.rst b/docs/source/model-export/export-with-torch-jit-script.rst new file mode 100644 index 000000000..a041dc1d5 --- /dev/null +++ b/docs/source/model-export/export-with-torch-jit-script.rst @@ -0,0 +1,58 @@ +.. _export-model-with-torch-jit-script: + +Export model with torch.jit.script() +==================================== + +In this section, we describe how to export a model via +``torch.jit.script()``. + +When to use it +-------------- + +If we want to use our trained model with torchscript, +we can use ``torch.jit.script()``. + +.. hint:: + + See :ref:`export-model-with-torch-jit-trace` + if you want to use ``torch.jit.trace()``. + +How to export +------------- + +We use +``_ +as an example in the following. + +.. code-block:: bash + + cd egs/librispeech/ASR + epoch=14 + avg=1 + + ./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch $epoch \ + --avg $avg \ + --jit 1 + +It will generate a file ``cpu_jit.pt`` in ``pruned_transducer_stateless3/exp``. + +.. caution:: + + Don't be confused by ``cpu`` in ``cpu_jit.pt``. We move all parameters + to CPU before saving it into a ``pt`` file; that's why we use ``cpu`` + in the filename.
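+
+The exported ``cpu_jit.pt`` is a regular TorchScript module, so it can also be
+loaded directly with ``torch.jit.load()`` for a quick sanity check. The
+following is only a minimal sketch (the path assumes the export command above);
+actual decoding is done by the scripts listed in the next section:
+
+.. code-block:: python
+
+   import torch
+
+   # Load the TorchScript model produced by "export.py --jit 1".
+   model = torch.jit.load("pruned_transducer_stateless3/exp/cpu_jit.pt")
+   model.eval()
+
+   # The parameters were saved on CPU, but the loaded module can be moved
+   # to a GPU if one is available.
+   if torch.cuda.is_available():
+       model = model.to("cuda")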
+ +How to use the exported model +----------------------------- + +Please refer to the following pages for usage: + +- ``_ +- ``_ +- ``_ +- ``_ +- ``_ +- ``_ diff --git a/docs/source/model-export/export-with-torch-jit-trace.rst b/docs/source/model-export/export-with-torch-jit-trace.rst new file mode 100644 index 000000000..506459909 --- /dev/null +++ b/docs/source/model-export/export-with-torch-jit-trace.rst @@ -0,0 +1,69 @@ +.. _export-model-with-torch-jit-trace: + +Export model with torch.jit.trace() +=================================== + +In this section, we describe how to export a model via +``torch.jit.trace()``. + +When to use it +-------------- + +If we want to use our trained model with torchscript, +we can use ``torch.jit.trace()``. + +.. hint:: + + See :ref:`export-model-with-torch-jit-script` + if you want to use ``torch.jit.script()``. + +How to export +------------- + +We use +``_ +as an example in the following. + +.. code-block:: bash + + iter=468000 + avg=16 + + cd egs/librispeech/ASR + + ./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --iter $iter \ + --avg $avg \ + --jit-trace 1 + +It will generate three files inside ``lstm_transducer_stateless2/exp``: + + - ``encoder_jit_trace.pt`` + - ``decoder_jit_trace.pt`` + - ``joiner_jit_trace.pt`` + +You can use +``_ +to decode sound files with the following commands: + +.. code-block:: bash + + cd egs/librispeech/ASR + ./lstm_transducer_stateless2/jit_pretrained.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \ + /path/to/foo.wav \ + /path/to/bar.wav \ + /path/to/baz.wav + +How to use the exported models +------------------------------ + +Please refer to +``_ +for its usage in `sherpa `_. +You can also find pretrained models there. diff --git a/docs/source/model-export/index.rst b/docs/source/model-export/index.rst new file mode 100644 index 000000000..9b7a2ee2d --- /dev/null +++ b/docs/source/model-export/index.rst @@ -0,0 +1,14 @@ +Model export +============ + +In this section, we describe various ways to export models. + + + +.. toctree:: + + export-model-state-dict + export-with-torch-jit-trace + export-with-torch-jit-script + export-onnx + export-ncnn diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/aishell/conformer_ctc.rst index 75a2a8eca..72690e102 100644 --- a/docs/source/recipes/aishell/conformer_ctc.rst +++ b/docs/source/recipes/aishell/conformer_ctc.rst @@ -422,7 +422,7 @@ The information of the test sound files is listed below: .. 
code-block:: bash - $ soxi tmp/icefall_asr_aishell_conformer_ctc/test_wavs/*.wav + $ soxi tmp/icefall_asr_aishell_conformer_ctc/test_waves/*.wav Input File : 'tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav' Channels : 1 @@ -485,9 +485,9 @@ The command to run CTC decoding is: --checkpoint ./tmp/icefall_asr_aishell_conformer_ctc/exp/pretrained.pt \ --tokens-file ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/tokens.txt \ --method ctc-decoding \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \ + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \ + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav The output is given below: @@ -529,9 +529,9 @@ The command to run HLG decoding is: --words-file ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \ --HLG ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \ --method 1best \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \ + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \ + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav The output is given below: @@ -575,9 +575,9 @@ The command to run HLG decoding + attention decoder rescoring is: --words-file ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/words.txt \ --HLG ./tmp/icefall_asr_aishell_conformer_ctc/data/lang_char/HLG.pt \ --method attention-decoder \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0121.wav \ + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0122.wav \ + ./tmp/icefall_asr_aishell_conformer_ctc/test_waves/BAC009S0764W0123.wav The output is below: diff --git a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/aishell/tdnn_lstm_ctc.rst index e9b0ea656..275931698 100644 --- a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst +++ b/docs/source/recipes/aishell/tdnn_lstm_ctc.rst @@ -402,7 +402,7 @@ The information of the test sound files is listed below: .. 
code-block:: bash - $ soxi tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_wavs/*.wav + $ soxi tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/*.wav Input File : 'tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0121.wav' Channels : 1 @@ -461,9 +461,9 @@ The command to run HLG decoding is: --words-file ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/data/lang_phone/words.txt \ --HLG ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/data/lang_phone/HLG.pt \ --method 1best \ - ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_wavs/BAC009S0764W0123.wav + ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0121.wav \ + ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0122.wav \ + ./tmp/icefall_asr_aishell_tdnn_lstm_ctc/test_waves/BAC009S0764W0123.wav The output is given below: diff --git a/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png b/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png new file mode 100644 index 000000000..cc475a45f Binary files /dev/null and b/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png differ diff --git a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/librispeech/index.rst index 5fa08ab6b..6c91b6750 100644 --- a/docs/source/recipes/librispeech/index.rst +++ b/docs/source/recipes/librispeech/index.rst @@ -6,3 +6,4 @@ LibriSpeech tdnn_lstm_ctc conformer_ctc + lstm_pruned_stateless_transducer diff --git a/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst new file mode 100644 index 000000000..643855cc2 --- /dev/null +++ b/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst @@ -0,0 +1,636 @@ +LSTM Transducer +=============== + +.. hint:: + + Please scroll down to the bottom of this page to find download links + for pretrained models if you don't want to train a model from scratch. + + +This tutorial shows you how to train an LSTM transducer model +with the `LibriSpeech `_ dataset. + +We use pruned RNN-T to compute the loss. + +.. note:: + + You can find the paper about pruned RNN-T at the following address: + + ``_ + +The transducer model consists of 3 parts: + + - Encoder, a.k.a, the transcription network. We use an LSTM model + - Decoder, a.k.a, the prediction network. We use a stateless model consisting of + ``nn.Embedding`` and ``nn.Conv1d`` + - Joiner, a.k.a, the joint network. + +.. caution:: + + Contrary to the conventional RNN-T models, we use a stateless decoder. + That is, it has no recurrent connections. + +.. hint:: + + Since the encoder model is an LSTM, not Transformer/Conformer, the + resulting model is suitable for streaming/online ASR. + + +Which model to use +------------------ + +Currently, there are two folders about LSTM stateless transducer training: + + - ``(1)`` ``_ + + This recipe uses only LibriSpeech during training. + + - ``(2)`` ``_ + + This recipe uses GigaSpeech + LibriSpeech during training. + +``(1)`` and ``(2)`` use the same model architecture. The only difference is that ``(2)`` supports +multi-dataset. Since ``(2)`` uses more data, it has a lower WER than ``(1)`` but it needs +more training time. + +We use ``lstm_transducer_stateless2`` as an example below. + +.. note:: + + You need to download the `GigaSpeech `_ dataset + to run ``(2)``. 
If you have only ``LibriSpeech`` dataset available, feel free to use ``(1)``. + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh + + # If you use (1), you can **skip** the following command + $ ./prepare_giga_speech.sh + +The script ``./prepare.sh`` handles the data preparation for you, **automagically**. +All you need to do is to run it. + +.. note:: + + We encourage you to read ``./prepare.sh``. + +The data preparation contains several stages. You can use the following two +options: + + - ``--stage`` + - ``--stop-stage`` + +to control which stage(s) should be run. By default, all stages are executed. + + +For example, + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./prepare.sh --stage 0 --stop-stage 0 + +means to run only stage 0. + +To run stage 2 to stage 5, use: + +.. code-block:: bash + + $ ./prepare.sh --stage 2 --stop-stage 5 + +.. hint:: + + If you have pre-downloaded the `LibriSpeech `_ + dataset and the `musan `_ dataset, say, + they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify + the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that + ``./prepare.sh`` won't re-download them. + +.. note:: + + All generated files by ``./prepare.sh``, e.g., features, lexicon, etc, + are saved in ``./data`` directory. + +We provide the following YouTube video showing how to run ``./prepare.sh``. + +.. note:: + + To get the latest news of `next-gen Kaldi `_, please subscribe + the following YouTube channel by `Nadira Povey `_: + + ``_ + +.. youtube:: ofEIoJL-mGM + +Training +-------- + +Configurable options +~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./lstm_transducer_stateless2/train.py --help + +shows you the training options that can be passed from the commandline. +The following options are used quite often: + + - ``--full-libri`` + + If it's True, the training part uses all the training data, i.e., + 960 hours. Otherwise, the training part uses only the subset + ``train-clean-100``, which has 100 hours of training data. + + .. CAUTION:: + + The training set is perturbed by speed with two factors: 0.9 and 1.1. + If ``--full-libri`` is True, each epoch actually processes + ``3x960 == 2880`` hours of data. + + - ``--num-epochs`` + + It is the number of epochs to train. For instance, + ``./lstm_transducer_stateless2/train.py --num-epochs 30`` trains for 30 epochs + and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt`` + in the folder ``./lstm_transducer_stateless2/exp``. + + - ``--start-epoch`` + + It's used to resume training. + ``./lstm_transducer_stateless2/train.py --start-epoch 10`` loads the + checkpoint ``./lstm_transducer_stateless2/exp/epoch-9.pt`` and starts + training from epoch 10, based on the state from epoch 9. + + - ``--world-size`` + + It is used for multi-GPU single-machine DDP training. + + - (a) If it is 1, then no DDP training is used. + + - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training. + + The following shows some use cases with it. + + **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and + GPU 2 for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="0,2" + $ ./lstm_transducer_stateless2/train.py --world-size 2 + + **Use case 2**: You have 4 GPUs and you want to use all of them + for training. You can do the following: + + .. 
code-block:: bash + + $ cd egs/librispeech/ASR + $ ./lstm_transducer_stateless2/train.py --world-size 4 + + **Use case 3**: You have 4 GPUs but you only want to use GPU 3 + for training. You can do the following: + + .. code-block:: bash + + $ cd egs/librispeech/ASR + $ export CUDA_VISIBLE_DEVICES="3" + $ ./lstm_transducer_stateless2/train.py --world-size 1 + + .. caution:: + + Only multi-GPU single-machine DDP training is implemented at present. + Multi-GPU multi-machine DDP training will be added later. + + - ``--max-duration`` + + It specifies the number of seconds over all utterances in a + batch, before **padding**. + If you encounter CUDA OOM, please reduce it. + + .. HINT:: + + Due to padding, the number of seconds of all utterances in a + batch will usually be larger than ``--max-duration``. + + A larger value for ``--max-duration`` may cause OOM during training, + while a smaller value may increase the training time. You have to + tune it. + + - ``--giga-prob`` + + The probability to select a batch from the ``GigaSpeech`` dataset. + Note: It is available only for ``(2)``. + +Pre-configured options +~~~~~~~~~~~~~~~~~~~~~~ + +There are some training options, e.g., weight decay, +number of warmup steps, results dir, etc, +that are not passed from the commandline. +They are pre-configured by the function ``get_params()`` in +`lstm_transducer_stateless2/train.py `_ + +You don't need to change these pre-configured parameters. If you really need to change +them, please modify ``./lstm_transducer_stateless2/train.py`` directly. + +Training logs +~~~~~~~~~~~~~ + +Training logs and checkpoints are saved in ``lstm_transducer_stateless2/exp``. +You will find the following files in that directory: + + - ``epoch-1.pt``, ``epoch-2.pt``, ... + + These are checkpoint files saved at the end of each epoch, containing model + ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``epoch-10.pt``, you can use: + + .. code-block:: bash + + $ ./lstm_transducer_stateless2/train.py --start-epoch 11 + + - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ... + + These are checkpoint files saved every ``--save-every-n`` batches, + containing model ``state_dict`` and optimizer ``state_dict``. + To resume training from some checkpoint, say ``checkpoint-436000``, you can use: + + .. code-block:: bash + + $ ./lstm_transducer_stateless2/train.py --start-batch 436000 + + - ``tensorboard/`` + + This folder contains tensorBoard logs. Training loss, validation loss, learning + rate, etc, are recorded in these logs. You can visualize them by: + + .. code-block:: bash + + $ cd lstm_transducer_stateless2/exp/tensorboard + $ tensorboard dev upload --logdir . --description "LSTM transducer training for LibriSpeech with icefall" + + It will print something like below: + + .. code-block:: + + TensorFlow installation not found - running with reduced feature set. + Upload started and will continue reading any new data as it's added to the logdir. + + To stop uploading, press Ctrl-C. + + New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/cj2vtPiwQHKN9Q1tx6PTpg/ + + [2022-09-20T15:50:50] Started scanning logdir. + Uploading 4468 scalars... + [2022-09-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects + Listening for new data in logdir... + + Note there is a URL in the above output. Click it and you will see + the following screenshot: + + .. 
figure:: images/librispeech-lstm-transducer-tensorboard-log.png + :width: 600 + :alt: TensorBoard screenshot + :align: center + :target: https://tensorboard.dev/experiment/lzGnETjwRxC3yghNMd4kPw/ + + TensorBoard screenshot. + + .. hint:: + + If you don't have access to google, you can use the following command + to view the tensorboard log locally: + + .. code-block:: bash + + cd lstm_transducer_stateless2/exp/tensorboard + tensorboard --logdir . --port 6008 + + It will print the following message: + + .. code-block:: + + Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all + TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit) + + Now start your browser and go to ``_ to view the tensorboard + logs. + + + - ``log/log-train-xxxx`` + + It is the detailed training log in text format, same as the one + you saw printed to the console during training. + +Usage example +~~~~~~~~~~~~~ + +You can use the following command to start the training using 8 GPUs: + +.. code-block:: bash + + export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + ./lstm_transducer_stateless2/train.py \ + --world-size 8 \ + --num-epochs 35 \ + --start-epoch 1 \ + --full-libri 1 \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 500 \ + --use-fp16 0 \ + --lr-epochs 10 \ + --num-workers 2 \ + --giga-prob 0.9 + +Decoding +-------- + +The decoding part uses checkpoints saved by the training part, so you have +to run the training part first. + +.. hint:: + + There are two kinds of checkpoints: + + - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end + of each epoch. You can pass ``--epoch`` to + ``lstm_transducer_stateless2/decode.py`` to use them. + + - (2) ``checkpoints-436000.pt``, ``epoch-438000.pt``, ..., which are saved + every ``--save-every-n`` batches. You can pass ``--iter`` to + ``lstm_transducer_stateless2/decode.py`` to use them. + + We suggest that you try both types of checkpoints and choose the one + that produces the lowest WERs. + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./lstm_transducer_stateless2/decode.py --help + +shows the options for decoding. + +The following shows two examples: + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in 17; do + for avg in 1 2; do + ./lstm_transducer_stateless2/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method $m \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --beam-size 4 + done + done + done + + +.. code-block:: bash + + for m in greedy_search fast_beam_search modified_beam_search; do + for iter in 474000; do + for avg in 8 10 12 14 16 18; do + ./lstm_transducer_stateless2/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method $m \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --beam-size 4 + done + done + done + +Export models +------------- + +`lstm_transducer_stateless2/export.py `_ supports exporting checkpoints from ``lstm_transducer_stateless2/exp`` in the following ways. + +Export ``model.state_dict()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoints saved by ``lstm_transducer_stateless2/train.py`` also include +``optimizer.state_dict()``. It is useful for resuming training. 
But after training, +we are interested only in ``model.state_dict()``. You can use the following +command to extract ``model.state_dict()``. + +.. code-block:: bash + + # Assume that --iter 468000 --avg 16 produces the smallest WER + # (You can get such information after running ./lstm_transducer_stateless2/decode.py) + + iter=468000 + avg=16 + + ./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --iter $iter \ + --avg $avg + +It will generate a file ``./lstm_transducer_stateless2/exp/pretrained.pt``. + +.. hint:: + + To use the generated ``pretrained.pt`` for ``lstm_transducer_stateless2/decode.py``, + you can run: + + .. code-block:: bash + + cd lstm_transducer_stateless2/exp + ln -s pretrained.pt epoch-9999.pt + + And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to + ``./lstm_transducer_stateless2/decode.py``. + +To use the exported model with ``./lstm_transducer_stateless2/pretrained.py``, you +can run: + +.. code-block:: bash + + ./lstm_transducer_stateless2/pretrained.py \ + --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +Export model using ``torch.jit.trace()`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + iter=468000 + avg=16 + + ./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --iter $iter \ + --avg $avg \ + --jit-trace 1 + +It will generate 3 files: + + - ``./lstm_transducer_stateless2/exp/encoder_jit_trace.pt`` + - ``./lstm_transducer_stateless2/exp/decoder_jit_trace.pt`` + - ``./lstm_transducer_stateless2/exp/joiner_jit_trace.pt`` + +To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained.py``: + +.. code-block:: bash + + ./lstm_transducer_stateless2/jit_pretrained.py \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --encoder-model-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace.pt \ + /path/to/foo.wav \ + /path/to/bar.wav + +.. hint:: + + Please see ``_ + for how to use the exported models in ``sherpa``. + +.. _export-model-for-ncnn: + +Export model for ncnn +~~~~~~~~~~~~~~~~~~~~~ + +We support exporting pretrained LSTM transducer models to +`ncnn `_ using +`pnnx `_. + +First, let us install a modified version of ``ncnn``: + +.. code-block:: bash + + git clone https://github.com/csukuangfj/ncnn + cd ncnn + git submodule update --recursive --init + python3 setup.py bdist_wheel + ls -lh dist/ + pip install ./dist/*.whl + + # now build pnnx + cd tools/pnnx + mkdir build + cd build + cmake .. + make -j4 + export PATH=$PWD/src:$PATH + + ./src/pnnx + +.. note:: + + We assume that you have added the path to the binary ``pnnx`` to the + environment variable ``PATH``. + +Second, let us export the model using ``torch.jit.trace()``, in a format that is suitable +for ``pnnx``: + +..
code-block:: bash + + iter=468000 + avg=16 + + ./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --iter $iter \ + --avg $avg \ + --pnnx 1 + +It will generate 3 files: + + - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt`` + - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt`` + - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt`` + +Third, convert torchscript model to ``ncnn`` format: + +.. code-block:: + + pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt + pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt + pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt + +It will generate the following files: + + - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param`` + - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin`` + - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param`` + - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin`` + - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param`` + - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin`` + +To use the above generated files, run: + +.. code-block:: bash + + ./lstm_transducer_stateless2/ncnn-decode.py \ + --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ + /path/to/foo.wav + +.. code-block:: bash + + ./lstm_transducer_stateless2/streaming-ncnn-decode.py \ + --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \ + /path/to/foo.wav + +To use the above generated files in C++, please see +``_ + +It is able to generate a static linked executable that can be run on Linux, Windows, +macOS, Raspberry Pi, etc, without external dependencies. 
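+
+If you want to inspect the generated ``*.ncnn.param`` and ``*.ncnn.bin`` files
+from Python instead of using the scripts above, the ``ncnn`` wheel installed
+earlier exposes the usual ``ncnn`` Python API. The following is only a sketch
+of loading the encoder; the input/output blob names and the fbank feature
+preparation depend on the exported model and are handled for you by
+``ncnn-decode.py`` and ``streaming-ncnn-decode.py``:
+
+.. code-block:: python
+
+   import ncnn
+
+   net = ncnn.Net()
+
+   # Load the network definition first, then the weights. Both files are
+   # produced by the pnnx conversion step above.
+   net.load_param("lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param")
+   net.load_model("lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin")
+
+   # Each forward pass goes through an extractor; see ncnn-decode.py for how
+   # the features are fed in and how the encoder output is consumed.
+   ex = net.create_extractor()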
+ +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following links: + + - ``_ + + - ``_ + + See ``_ + for the details of the above pretrained models + +You can find more usages of the pretrained models in +``_ diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py index 1e3e7b492..cb7205e51 100644 --- a/egs/aishell/ASR/conformer_ctc/conformer.py +++ b/egs/aishell/ASR/conformer_ctc/conformer.py @@ -248,7 +248,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) + src = residual + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) if not self.normalize_before: src = self.norm_conv(src) @@ -879,11 +881,16 @@ class ConvolutionModule(nn.Module): ) self.activation = Swish() - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). @@ -897,6 +904,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) x = self.activation(self.norm(x)) diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index a3e7f98e3..751b7d5b5 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -335,7 +335,7 @@ def decode_dataset( lexicon: Lexicon, sos_id: int, eos_id: int, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -410,7 +410,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): if params.method == "attention-decoder": # Set it to False since there are too many logs. diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py index 1e3e7b492..cb7205e51 100644 --- a/egs/aishell/ASR/conformer_mmi/conformer.py +++ b/egs/aishell/ASR/conformer_mmi/conformer.py @@ -248,7 +248,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) + src = residual + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) if not self.normalize_before: src = self.norm_conv(src) @@ -879,11 +881,16 @@ class ConvolutionModule(nn.Module): ) self.activation = Swish() - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). 
@@ -897,6 +904,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) x = self.activation(self.norm(x)) diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index ac68b61e7..4db367e36 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -347,7 +347,7 @@ def decode_dataset( lexicon: Lexicon, sos_id: int, eos_id: int, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -422,7 +422,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): if params.method == "attention-decoder": # Set it to False since there are too many logs. diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index 3a68e8765..a12934d55 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -326,7 +326,7 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -396,7 +396,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 3268c8bb2..d159e420b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -340,7 +340,7 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -410,7 +410,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index 45c1c4ec1..66b734fc4 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -208,7 +208,7 @@ def decode_dataset( model: nn.Module, HLG: k2.Fsa, lexicon: Lexicon, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -274,7 +274,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py index 66eb3eb63..64114253d 100644 --- a/egs/aishell/ASR/transducer_stateless/conformer.py +++ b/egs/aishell/ASR/transducer_stateless/conformer.py @@ -246,7 +246,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) + src = residual + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) if not self.normalize_before: src = self.norm_conv(src) @@ -877,11 +879,16 @@ class ConvolutionModule(nn.Module): ) self.activation = Swish() - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). @@ -895,6 +902,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) # x is (batch, channels, time) x = x.permute(0, 2, 1) diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index d78821b95..780b0c4bb 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -264,7 +264,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, lexicon: Lexicon, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -328,7 +328,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index a9dca995f..ea3f94fd8 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -304,7 +304,7 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -374,7 +374,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 9e827e1d1..65fcda873 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -308,7 +308,7 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -378,7 +378,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 7d6f6f6d5..915737f4a 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -478,7 +478,7 @@ def decode_dataset( lexicon: Lexicon, graph_compiler: CharCtcTrainingGraphCompiler, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -547,7 +547,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 739dfc4a1..14e44c7d9 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -342,7 +342,7 @@ def decode_dataset( model: nn.Module, lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -410,7 +410,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index 65fc74728..6358fe970 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -331,7 +331,7 @@ def decode_dataset( model: nn.Module, lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -399,7 +399,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore new file mode 100644 index 000000000..c0a162e20 --- /dev/null +++ b/egs/csj/ASR/.gitignore @@ -0,0 +1,7 @@ +librispeech_*.* +todelete* +lang* +notify_tg.py +finetune_* +misc.ini +.vscode/* \ No newline at end of file diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py new file mode 100644 index 000000000..994dedbdd --- /dev/null +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import os +from itertools import islice +from pathlib import Path +from random import Random +from typing import List, Tuple + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + # fmt: off + # See the following for why LilcomChunkyWriter is preferred + # https://github.com/k2-fsa/icefall/pull/404 + # https://github.com/lhotse-speech/lhotse/pull/527 + # fmt: on + LilcomChunkyWriter, + RecordingSet, + SupervisionSet, +) + +ARGPARSE_DESCRIPTION = """ +This script follows the espnet method of splitting the remaining core+noncore +utterances into valid and train cutsets at an index which is by default 4000. + +In other words, the core+noncore utterances are shuffled, where 4000 utterances +of the shuffled set go to the `valid` cutset and are not subject to speed +perturbation. The remaining utterances become the `train` cutset and are speed- +perturbed (0.9x, 1.0x, 1.1x). + +""" + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +RNG_SEED = 42 + + +def make_cutset_blueprints( + manifest_dir: Path, + split: int, +) -> List[Tuple[str, CutSet]]: + + cut_sets = [] + # Create eval datasets + logging.info("Creating eval cuts.") + for i in range(1, 4): + cut_set = CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / f"csj_recordings_eval{i}.jsonl.gz" + ), + supervisions=SupervisionSet.from_file( + manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" + ), + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_sets.append((f"eval{i}", cut_set)) + + # Create train and valid cuts + logging.info( + "Loading, trimming, and shuffling the remaining core+noncore cuts." 
+ ) + recording_set = RecordingSet.from_file( + manifest_dir / "csj_recordings_core.jsonl.gz" + ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") + supervision_set = SupervisionSet.from_file( + manifest_dir / "csj_supervisions_core.jsonl.gz" + ) + SupervisionSet.from_file( + manifest_dir / "csj_supervisions_noncore.jsonl.gz" + ) + + cut_set = CutSet.from_manifests( + recordings=recording_set, + supervisions=supervision_set, + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_set = cut_set.shuffle(Random(RNG_SEED)) + + logging.info( + "Creating valid and train cuts from core and noncore," + f"split at {split}." + ) + valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) + + train_set = CutSet.from_cuts(islice(cut_set, split, None)) + train_set = ( + train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) + ) + + cut_sets.extend([("valid", valid_set), ("train", train_set)]) + + return cut_sets + + +def get_args(): + parser = argparse.ArgumentParser( + description=ARGPARSE_DESCRIPTION, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "--fbank-dir", type=Path, help="Path to save fbank features" + ) + parser.add_argument( + "--split", type=int, default=4000, help="Split at this index" + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + extractor = Fbank(FbankConfig(num_mel_bins=80)) + num_jobs = min(16, os.cpu_count()) + + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + if (args.fbank_dir / ".done").exists(): + logging.info( + "Previous fbank computed for CSJ found. " + f"Delete {args.fbank_dir / '.done'} to allow recomputing fbank." + ) + return + else: + cut_sets = make_cutset_blueprints(args.manifest_dir, args.split) + for part, cut_set in cut_sets: + logging.info(f"Processing {part}") + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + num_jobs=num_jobs, + storage_path=(args.fbank_dir / f"feats_{part}").as_posix(), + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz") + + logging.info("All fbank computed for CSJ.") + (args.fbank_dir / ".done").touch() + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini new file mode 100644 index 000000000..eb70673de --- /dev/null +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -0,0 +1,321 @@ +; # This section is ignored if this file is not supplied as the first config file to +; # lhotse prepare csj +[SEGMENTS] +; # Allowed period of nonverbal noise. If exceeded, a new segment is created. +gap = 0.5 +; # Maximum length of segment (s). +maxlen = 10 +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +minlen = 0.02 +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. 
+; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = + +[CONSTANTS] +; # Name of this mode +MODE = disfluent +; # Suffixes to use after the word surface (no longer used) +MORPH = pos1 cForm cType2 pos2 +; # Used to differentiate between A tag and A_num tag +JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . +; # Dummy character to delineate multiline words +PLUS = + + +[DECISIONS] +; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 +; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries + +; # フィラー、感情表出系感動詞 +; # 0 to remain, 1 to delete +; # Example: '(F ぎょっ)' +F = 0 +; # Example: '(L (F ン))', '比べ(F えー)る' +F^ = 0 +; # 言い直し、いいよどみなどによる語断片 +; # 0 to remain, 1 to delete +; # Example: '(D だ)(D だいが) 大学の学部の会議' +D = 0 +; # Example: '(L (D ドゥ)+(D ヒ))' +D^ = 0 +; # 助詞、助動詞、接辞の言い直し +; # 0 to remain, 1 to delete +; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' +D2 = 0 +; # Example: '(X (D2 ノ))' +D2^ = 0 +; # 聞き取りや語彙の判断に自信がない場合 +; # 0 to remain, 1 to delete +; # Example: (? 字数) の +; # If no option: empty string is returned regardless of output +; # Example: '(?) で' +? = 0 +; # Example: '(D (? すー))+そう+です+よ+ね' +?^ = 0 +; # タグ?で、値は複数の候補が想定される場合 +; # 0 for main guess with matching morph info, 1 for second guess +; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' +?, = 0 +; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' +?,^ = 0 +; # 音や言葉に関するメタ的な引用 +; # 0 to remain, 1 to delete +; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' +M = 0 +; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' +M^ = 0 +; # 外国語や古語、方言など +; # 0 to remain, 1 to delete +; # Example: '(O ザッツファイン)' +O = 0 +; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' +O^ = 0 +; # 講演者の名前、差別語、誹謗中傷など +; # 0 to remain, 1 to delete +; # Example: '国語研の (R ××) です' +R = 0 +R^ = 0 +; # 非朗読対象発話(朗読における言い間違い等) +; # 0 to remain, 1 to delete +; # Example: '(X 実際は) 実際には' +X = 0 +; # Example: '(L (X (D2 ニ)))' +X^ = 0 +; # アルファベットや算用数字、記号の表記 +; # 0 to use Japanese form, 1 to use alphabet form +; # Example: '(A シーディーアール;CD-R)' +A = 1 +; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') +A^ = 1 +; # タグAで、単語は算用数字の場合 +; # 0 to use Japanese form, 1 to use Arabic numerals +; # Example: (A 二千;2000) +A_num = eval:self.notag +A_num^ = eval:self.notag +; # 何らかの原因で漢字表記できなくなった場合 +; # 0 to use broken form, 1 to use orthodox form +; # Example: '(K たち (F えー) ばな;橘)' +K = 1 +; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' +K^ = 1 +; # 転訛、発音の怠けなど、一時的な発音エラー +; # 0 to use wrong form, 1 to use orthodox form +; # Example: '(W ギーツ;ギジュツ)' +W = 1 +; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' +W^ = 1 +; # 語の読みに関する知識レベルのいい間違い +; # 0 to use wrong form, 1 to use orthodox form +; # Example: '(B シブタイ;ジュータイ)' +B = 0 +; # Example: 'データー(B カズ;スー)' +B^ = 0 +; # 笑いながら発話 +; # 0 to remain, 1 to delete +; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' +笑 = 0 +; # Example: 'コク(笑 サイ+(D オン))', +笑^ = 0 +; # 泣きながら発話 +; # 0 to remain, 1 to delete +; # Example: '(泣 ドンナニ)' +泣 = 0 +泣^ = 0 +; # 咳をしながら発話 +; # 0 to remain, 1 to delete +; # Example: 'シャ(咳 リン) ノ' +咳 = 0 +; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' +咳^ = 0 +; # ささやき声や独り言などの小さな声 +; # 0 to remain, 1 to delete +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? 
セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +L = 0 +; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' +L^ = 0 + +[REPLACEMENTS] +; # ボーカルフライなどで母音が同定できない場合 + = +; # 「うん/うーん/ふーん」の音の特定が困難な場合 + = +; # 非語彙的な母音の引き延ばし + = +; # 非語彙的な子音の引き延ばし + = +; # 言語音と独立に講演者の笑いが生じている場合 +<笑> = +; # 言語音と独立に講演者の咳が生じている場合 +<咳> = +; # 言語音と独立に講演者の息が生じている場合 +<息> = +; # 講演者の泣き声 +<泣> = +; # 聴衆(司会者なども含む)の発話 +<フロア発話> = +; # 聴衆の笑い +<フロア笑> = +; # 聴衆の拍手 +<拍手> = +; # 講演者が発表中に用いたデモンストレーションの音声 +<デモ> = +; # 学会講演に発表時間を知らせるためにならすベルの音 +<ベル> = +; # 転記単位全体が再度読み直された場合 +<朗読間違い> = +; # 上記以外の音で特に目立った音 +<雑音> = +; # 0.2秒以上のポーズ +

= +; # Redacted information, for R +; # It is \x00D7 multiplication sign, not your normal 'x' +× = × + +[FIELDS] +; # Time information for segment +time = 3 +; # Word surface +surface = 5 +; # Word surface root form without CSJ tags +notag = 9 +; # Part Of Speech +pos1 = 11 +; # Conjugated Form +cForm = 12 +; # Conjugation Type +cType1 = 13 +; # Subcategory of POS +pos2 = 14 +; # Euphonic Change / Subcategory of Conjugation Type +cType2 = 15 +; # Other information +other = 16 +; # Pronunciation for lexicon +pron = 10 +; # Speaker ID +spk_id = 2 + +[KATAKANA2ROMAJI] +ア = 'a +イ = 'i +ウ = 'u +エ = 'e +オ = 'o +カ = ka +キ = ki +ク = ku +ケ = ke +コ = ko +ガ = ga +ギ = gi +グ = gu +ゲ = ge +ゴ = go +サ = sa +シ = si +ス = su +セ = se +ソ = so +ザ = za +ジ = zi +ズ = zu +ゼ = ze +ゾ = zo +タ = ta +チ = ti +ツ = tu +テ = te +ト = to +ダ = da +ヂ = di +ヅ = du +デ = de +ド = do +ナ = na +ニ = ni +ヌ = nu +ネ = ne +ノ = no +ハ = ha +ヒ = hi +フ = hu +ヘ = he +ホ = ho +バ = ba +ビ = bi +ブ = bu +ベ = be +ボ = bo +パ = pa +ピ = pi +プ = pu +ペ = pe +ポ = po +マ = ma +ミ = mi +ム = mu +メ = me +モ = mo +ヤ = ya +ユ = yu +ヨ = yo +ラ = ra +リ = ri +ル = ru +レ = re +ロ = ro +ワ = wa +ヰ = we +ヱ = wi +ヲ = wo +ン = ŋ +ッ = q +ー = - +キャ = kǐa +キュ = kǐu +キョ = kǐo +ギャ = gǐa +ギュ = gǐu +ギョ = gǐo +シャ = sǐa +シュ = sǐu +ショ = sǐo +ジャ = zǐa +ジュ = zǐu +ジョ = zǐo +チャ = tǐa +チュ = tǐu +チョ = tǐo +ヂャ = dǐa +ヂュ = dǐu +ヂョ = dǐo +ニャ = nǐa +ニュ = nǐu +ニョ = nǐo +ヒャ = hǐa +ヒュ = hǐu +ヒョ = hǐo +ビャ = bǐa +ビュ = bǐu +ビョ = bǐo +ピャ = pǐa +ピュ = pǐu +ピョ = pǐo +ミャ = mǐa +ミュ = mǐu +ミョ = mǐo +リャ = rǐa +リュ = rǐu +リョ = rǐo +ァ = a +ィ = i +ゥ = u +ェ = e +ォ = o +ヮ = ʍ +ヴ = vu +ャ = ǐa +ュ = ǐu +ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini new file mode 100644 index 000000000..5d22f9eb8 --- /dev/null +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -0,0 +1,321 @@ +; # This section is ignored if this file is not supplied as the first config file to +; # lhotse prepare csj +[SEGMENTS] +; # Allowed period of nonverbal noise. If exceeded, a new segment is created. +gap = 0.5 +; # Maximum length of segment (s). +maxlen = 10 +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +minlen = 0.02 +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = + +[CONSTANTS] +; # Name of this mode +MODE = fluent +; # Suffixes to use after the word surface (no longer used) +MORPH = pos1 cForm cType2 pos2 +; # Used to differentiate between A tag and A_num tag +JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . +; # Dummy character to delineate multiline words +PLUS = + + +[DECISIONS] +; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 +; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries + +; # フィラー、感情表出系感動詞 +; # 0 to remain, 1 to delete +; # Example: '(F ぎょっ)' +F = 1 +; # Example: '(L (F ン))', '比べ(F えー)る' +F^ = 1 +; # 言い直し、いいよどみなどによる語断片 +; # 0 to remain, 1 to delete +; # Example: '(D だ)(D だいが) 大学の学部の会議' +D = 1 +; # Example: '(L (D ドゥ)+(D ヒ))' +D^ = 1 +; # 助詞、助動詞、接辞の言い直し +; # 0 to remain, 1 to delete +; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' +D2 = 1 +; # Example: '(X (D2 ノ))' +D2^ = 1 +; # 聞き取りや語彙の判断に自信がない場合 +; # 0 to remain, 1 to delete +; # Example: (? 字数) の +; # If no option: empty string is returned regardless of output +; # Example: '(?) で' +? = 0 +; # Example: '(D (? 
すー))+そう+です+よ+ね' +?^ = 0 +; # タグ?で、値は複数の候補が想定される場合 +; # 0 for main guess with matching morph info, 1 for second guess +; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' +?, = 0 +; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' +?,^ = 0 +; # 音や言葉に関するメタ的な引用 +; # 0 to remain, 1 to delete +; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' +M = 0 +; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' +M^ = 0 +; # 外国語や古語、方言など +; # 0 to remain, 1 to delete +; # Example: '(O ザッツファイン)' +O = 0 +; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' +O^ = 0 +; # 講演者の名前、差別語、誹謗中傷など +; # 0 to remain, 1 to delete +; # Example: '国語研の (R ××) です' +R = 0 +R^ = 0 +; # 非朗読対象発話(朗読における言い間違い等) +; # 0 to remain, 1 to delete +; # Example: '(X 実際は) 実際には' +X = 0 +; # Example: '(L (X (D2 ニ)))' +X^ = 0 +; # アルファベットや算用数字、記号の表記 +; # 0 to use Japanese form, 1 to use alphabet form +; # Example: '(A シーディーアール;CD-R)' +A = 1 +; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') +A^ = 1 +; # タグAで、単語は算用数字の場合 +; # 0 to use Japanese form, 1 to use Arabic numerals +; # Example: (A 二千;2000) +A_num = eval:self.notag +A_num^ = eval:self.notag +; # 何らかの原因で漢字表記できなくなった場合 +; # 0 to use broken form, 1 to use orthodox form +; # Example: '(K たち (F えー) ばな;橘)' +K = 1 +; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' +K^ = 1 +; # 転訛、発音の怠けなど、一時的な発音エラー +; # 0 to use wrong form, 1 to use orthodox form +; # Example: '(W ギーツ;ギジュツ)' +W = 1 +; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' +W^ = 1 +; # 語の読みに関する知識レベルのいい間違い +; # 0 to use wrong form, 1 to use orthodox form +; # Example: '(B シブタイ;ジュータイ)' +B = 0 +; # Example: 'データー(B カズ;スー)' +B^ = 0 +; # 笑いながら発話 +; # 0 to remain, 1 to delete +; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' +笑 = 0 +; # Example: 'コク(笑 サイ+(D オン))', +笑^ = 0 +; # 泣きながら発話 +; # 0 to remain, 1 to delete +; # Example: '(泣 ドンナニ)' +泣 = 0 +泣^ = 0 +; # 咳をしながら発話 +; # 0 to remain, 1 to delete +; # Example: 'シャ(咳 リン) ノ' +咳 = 0 +; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' +咳^ = 0 +; # ささやき声や独り言などの小さな声 +; # 0 to remain, 1 to delete +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +L = 0 +; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' +L^ = 0 + +[REPLACEMENTS] +; # ボーカルフライなどで母音が同定できない場合 + = +; # 「うん/うーん/ふーん」の音の特定が困難な場合 + = +; # 非語彙的な母音の引き延ばし + = +; # 非語彙的な子音の引き延ばし + = +; # 言語音と独立に講演者の笑いが生じている場合 +<笑> = +; # 言語音と独立に講演者の咳が生じている場合 +<咳> = +; # 言語音と独立に講演者の息が生じている場合 +<息> = +; # 講演者の泣き声 +<泣> = +; # 聴衆(司会者なども含む)の発話 +<フロア発話> = +; # 聴衆の笑い +<フロア笑> = +; # 聴衆の拍手 +<拍手> = +; # 講演者が発表中に用いたデモンストレーションの音声 +<デモ> = +; # 学会講演に発表時間を知らせるためにならすベルの音 +<ベル> = +; # 転記単位全体が再度読み直された場合 +<朗読間違い> = +; # 上記以外の音で特に目立った音 +<雑音> = +; # 0.2秒以上のポーズ +

= +; # Redacted information, for R +; # It is \x00D7 multiplication sign, not your normal 'x' +× = × + +[FIELDS] +; # Time information for segment +time = 3 +; # Word surface +surface = 5 +; # Word surface root form without CSJ tags +notag = 9 +; # Part Of Speech +pos1 = 11 +; # Conjugated Form +cForm = 12 +; # Conjugation Type +cType1 = 13 +; # Subcategory of POS +pos2 = 14 +; # Euphonic Change / Subcategory of Conjugation Type +cType2 = 15 +; # Other information +other = 16 +; # Pronunciation for lexicon +pron = 10 +; # Speaker ID +spk_id = 2 + +[KATAKANA2ROMAJI] +ア = 'a +イ = 'i +ウ = 'u +エ = 'e +オ = 'o +カ = ka +キ = ki +ク = ku +ケ = ke +コ = ko +ガ = ga +ギ = gi +グ = gu +ゲ = ge +ゴ = go +サ = sa +シ = si +ス = su +セ = se +ソ = so +ザ = za +ジ = zi +ズ = zu +ゼ = ze +ゾ = zo +タ = ta +チ = ti +ツ = tu +テ = te +ト = to +ダ = da +ヂ = di +ヅ = du +デ = de +ド = do +ナ = na +ニ = ni +ヌ = nu +ネ = ne +ノ = no +ハ = ha +ヒ = hi +フ = hu +ヘ = he +ホ = ho +バ = ba +ビ = bi +ブ = bu +ベ = be +ボ = bo +パ = pa +ピ = pi +プ = pu +ペ = pe +ポ = po +マ = ma +ミ = mi +ム = mu +メ = me +モ = mo +ヤ = ya +ユ = yu +ヨ = yo +ラ = ra +リ = ri +ル = ru +レ = re +ロ = ro +ワ = wa +ヰ = we +ヱ = wi +ヲ = wo +ン = ŋ +ッ = q +ー = - +キャ = kǐa +キュ = kǐu +キョ = kǐo +ギャ = gǐa +ギュ = gǐu +ギョ = gǐo +シャ = sǐa +シュ = sǐu +ショ = sǐo +ジャ = zǐa +ジュ = zǐu +ジョ = zǐo +チャ = tǐa +チュ = tǐu +チョ = tǐo +ヂャ = dǐa +ヂュ = dǐu +ヂョ = dǐo +ニャ = nǐa +ニュ = nǐu +ニョ = nǐo +ヒャ = hǐa +ヒュ = hǐu +ヒョ = hǐo +ビャ = bǐa +ビュ = bǐu +ビョ = bǐo +ピャ = pǐa +ピュ = pǐu +ピョ = pǐo +ミャ = mǐa +ミュ = mǐu +ミョ = mǐo +リャ = rǐa +リュ = rǐu +リョ = rǐo +ァ = a +ィ = i +ゥ = u +ェ = e +ォ = o +ヮ = ʍ +ヴ = vu +ャ = ǐa +ュ = ǐu +ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini new file mode 100644 index 000000000..2613c3409 --- /dev/null +++ b/egs/csj/ASR/local/conf/number.ini @@ -0,0 +1,321 @@ +; # This section is ignored if this file is not supplied as the first config file to +; # lhotse prepare csj +[SEGMENTS] +; # Allowed period of nonverbal noise. If exceeded, a new segment is created. +gap = 0.5 +; # Maximum length of segment (s). +maxlen = 10 +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +minlen = 0.02 +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = + +[CONSTANTS] +; # Name of this mode +MODE = number +; # Suffixes to use after the word surface (no longer used) +MORPH = pos1 cForm cType2 pos2 +; # Used to differentiate between A tag and A_num tag +JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . +; # Dummy character to delineate multiline words +PLUS = + + +[DECISIONS] +; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 +; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries + +; # フィラー、感情表出系感動詞 +; # 0 to remain, 1 to delete +; # Example: '(F ぎょっ)' +F = 1 +; # Example: '(L (F ン))', '比べ(F えー)る' +F^ = 1 +; # 言い直し、いいよどみなどによる語断片 +; # 0 to remain, 1 to delete +; # Example: '(D だ)(D だいが) 大学の学部の会議' +D = 1 +; # Example: '(L (D ドゥ)+(D ヒ))' +D^ = 1 +; # 助詞、助動詞、接辞の言い直し +; # 0 to remain, 1 to delete +; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' +D2 = 1 +; # Example: '(X (D2 ノ))' +D2^ = 1 +; # 聞き取りや語彙の判断に自信がない場合 +; # 0 to remain, 1 to delete +; # Example: (? 字数) の +; # If no option: empty string is returned regardless of output +; # Example: '(?) で' +? = 0 +; # Example: '(D (? 
すー))+そう+です+よ+ね' +?^ = 0 +; # タグ?で、値は複数の候補が想定される場合 +; # 0 for main guess with matching morph info, 1 for second guess +; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' +?, = 0 +; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' +?,^ = 0 +; # 音や言葉に関するメタ的な引用 +; # 0 to remain, 1 to delete +; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' +M = 0 +; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' +M^ = 0 +; # 外国語や古語、方言など +; # 0 to remain, 1 to delete +; # Example: '(O ザッツファイン)' +O = 0 +; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' +O^ = 0 +; # 講演者の名前、差別語、誹謗中傷など +; # 0 to remain, 1 to delete +; # Example: '国語研の (R ××) です' +R = 0 +R^ = 0 +; # 非朗読対象発話(朗読における言い間違い等) +; # 0 to remain, 1 to delete +; # Example: '(X 実際は) 実際には' +X = 0 +; # Example: '(L (X (D2 ニ)))' +X^ = 0 +; # アルファベットや算用数字、記号の表記 +; # 0 to use Japanese form, 1 to use alphabet form +; # Example: '(A シーディーアール;CD-R)' +A = 1 +; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') +A^ = 1 +; # タグAで、単語は算用数字の場合 +; # 0 to use Japanese form, 1 to use Arabic numerals +; # Example: (A 二千;2000) +A_num = 1 +A_num^ = 1 +; # 何らかの原因で漢字表記できなくなった場合 +; # 0 to use broken form, 1 to use orthodox form +; # Example: '(K たち (F えー) ばな;橘)' +K = 1 +; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' +K^ = 1 +; # 転訛、発音の怠けなど、一時的な発音エラー +; # 0 to use wrong form, 1 to use orthodox form +; # Example: '(W ギーツ;ギジュツ)' +W = 1 +; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' +W^ = 1 +; # 語の読みに関する知識レベルのいい間違い +; # 0 to use wrong form, 1 to use orthodox form +; # Example: '(B シブタイ;ジュータイ)' +B = 0 +; # Example: 'データー(B カズ;スー)' +B^ = 0 +; # 笑いながら発話 +; # 0 to remain, 1 to delete +; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' +笑 = 0 +; # Example: 'コク(笑 サイ+(D オン))', +笑^ = 0 +; # 泣きながら発話 +; # 0 to remain, 1 to delete +; # Example: '(泣 ドンナニ)' +泣 = 0 +泣^ = 0 +; # 咳をしながら発話 +; # 0 to remain, 1 to delete +; # Example: 'シャ(咳 リン) ノ' +咳 = 0 +; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' +咳^ = 0 +; # ささやき声や独り言などの小さな声 +; # 0 to remain, 1 to delete +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +L = 0 +; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' +L^ = 0 + +[REPLACEMENTS] +; # ボーカルフライなどで母音が同定できない場合 + = +; # 「うん/うーん/ふーん」の音の特定が困難な場合 + = +; # 非語彙的な母音の引き延ばし + = +; # 非語彙的な子音の引き延ばし + = +; # 言語音と独立に講演者の笑いが生じている場合 +<笑> = +; # 言語音と独立に講演者の咳が生じている場合 +<咳> = +; # 言語音と独立に講演者の息が生じている場合 +<息> = +; # 講演者の泣き声 +<泣> = +; # 聴衆(司会者なども含む)の発話 +<フロア発話> = +; # 聴衆の笑い +<フロア笑> = +; # 聴衆の拍手 +<拍手> = +; # 講演者が発表中に用いたデモンストレーションの音声 +<デモ> = +; # 学会講演に発表時間を知らせるためにならすベルの音 +<ベル> = +; # 転記単位全体が再度読み直された場合 +<朗読間違い> = +; # 上記以外の音で特に目立った音 +<雑音> = +; # 0.2秒以上のポーズ +

= +; # Redacted information, for R +; # It is \x00D7 multiplication sign, not your normal 'x' +× = × + +[FIELDS] +; # Time information for segment +time = 3 +; # Word surface +surface = 5 +; # Word surface root form without CSJ tags +notag = 9 +; # Part Of Speech +pos1 = 11 +; # Conjugated Form +cForm = 12 +; # Conjugation Type +cType1 = 13 +; # Subcategory of POS +pos2 = 14 +; # Euphonic Change / Subcategory of Conjugation Type +cType2 = 15 +; # Other information +other = 16 +; # Pronunciation for lexicon +pron = 10 +; # Speaker ID +spk_id = 2 + +[KATAKANA2ROMAJI] +ア = 'a +イ = 'i +ウ = 'u +エ = 'e +オ = 'o +カ = ka +キ = ki +ク = ku +ケ = ke +コ = ko +ガ = ga +ギ = gi +グ = gu +ゲ = ge +ゴ = go +サ = sa +シ = si +ス = su +セ = se +ソ = so +ザ = za +ジ = zi +ズ = zu +ゼ = ze +ゾ = zo +タ = ta +チ = ti +ツ = tu +テ = te +ト = to +ダ = da +ヂ = di +ヅ = du +デ = de +ド = do +ナ = na +ニ = ni +ヌ = nu +ネ = ne +ノ = no +ハ = ha +ヒ = hi +フ = hu +ヘ = he +ホ = ho +バ = ba +ビ = bi +ブ = bu +ベ = be +ボ = bo +パ = pa +ピ = pi +プ = pu +ペ = pe +ポ = po +マ = ma +ミ = mi +ム = mu +メ = me +モ = mo +ヤ = ya +ユ = yu +ヨ = yo +ラ = ra +リ = ri +ル = ru +レ = re +ロ = ro +ワ = wa +ヰ = we +ヱ = wi +ヲ = wo +ン = ŋ +ッ = q +ー = - +キャ = kǐa +キュ = kǐu +キョ = kǐo +ギャ = gǐa +ギュ = gǐu +ギョ = gǐo +シャ = sǐa +シュ = sǐu +ショ = sǐo +ジャ = zǐa +ジュ = zǐu +ジョ = zǐo +チャ = tǐa +チュ = tǐu +チョ = tǐo +ヂャ = dǐa +ヂュ = dǐu +ヂョ = dǐo +ニャ = nǐa +ニュ = nǐu +ニョ = nǐo +ヒャ = hǐa +ヒュ = hǐu +ヒョ = hǐo +ビャ = bǐa +ビュ = bǐu +ビョ = bǐo +ピャ = pǐa +ピュ = pǐu +ピョ = pǐo +ミャ = mǐa +ミュ = mǐu +ミョ = mǐo +リャ = rǐa +リュ = rǐu +リョ = rǐo +ァ = a +ィ = i +ゥ = u +ェ = e +ォ = o +ヮ = ʍ +ヴ = vu +ャ = ǐa +ュ = ǐu +ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini new file mode 100644 index 000000000..8ba451dd5 --- /dev/null +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -0,0 +1,322 @@ +; # This section is ignored if this file is not supplied as the first config file to +; # lhotse prepare csj +[SEGMENTS] +; # Allowed period of nonverbal noise. If exceeded, a new segment is created. +gap = 0.5 +; # Maximum length of segment (s). +maxlen = 10 +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +minlen = 0.02 +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = + +[CONSTANTS] +; # Name of this mode +; # See https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf +MODE = symbol +; # Suffixes to use after the word surface (no longer used) +MORPH = pos1 cForm cType2 pos2 +; # Used to differentiate between A tag and A_num tag +JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . +; # Dummy character to delineate multiline words +PLUS = + + +[DECISIONS] +; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 +; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries + +; # フィラー、感情表出系感動詞 +; # 0 to remain, 1 to delete +; # Example: '(F ぎょっ)' +F = # +; # Example: '(L (F ン))', '比べ(F えー)る' +F^ = # +; # 言い直し、いいよどみなどによる語断片 +; # 0 to remain, 1 to delete +; # Example: '(D だ)(D だいが) 大学の学部の会議' +D = @ +; # Example: '(L (D ドゥ)+(D ヒ))' +D^ = @ +; # 助詞、助動詞、接辞の言い直し +; # 0 to remain, 1 to delete +; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' +D2 = @ +; # Example: '(X (D2 ノ))' +D2^ = @ +; # 聞き取りや語彙の判断に自信がない場合 +; # 0 to remain, 1 to delete +; # Example: (? 
字数) の +; # If no option: empty string is returned regardless of output +; # Example: '(?) で' +? = 0 +; # Example: '(D (? すー))+そう+です+よ+ね' +?^ = 0 +; # タグ?で、値は複数の候補が想定される場合 +; # 0 for main guess with matching morph info, 1 for second guess +; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' +?, = 0 +; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' +?,^ = 0 +; # 音や言葉に関するメタ的な引用 +; # 0 to remain, 1 to delete +; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' +M = 0 +; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' +M^ = 0 +; # 外国語や古語、方言など +; # 0 to remain, 1 to delete +; # Example: '(O ザッツファイン)' +O = 0 +; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' +O^ = 0 +; # 講演者の名前、差別語、誹謗中傷など +; # 0 to remain, 1 to delete +; # Example: '国語研の (R ××) です' +R = 0 +R^ = 0 +; # 非朗読対象発話(朗読における言い間違い等) +; # 0 to remain, 1 to delete +; # Example: '(X 実際は) 実際には' +X = 0 +; # Example: '(L (X (D2 ニ)))' +X^ = 0 +; # アルファベットや算用数字、記号の表記 +; # 0 to use Japanese form, 1 to use alphabet form +; # Example: '(A シーディーアール;CD-R)' +A = 1 +; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') +A^ = 1 +; # タグAで、単語は算用数字の場合 +; # 0 to use Japanese form, 1 to use Arabic numerals +; # Example: (A 二千;2000) +A_num = eval:self.notag +A_num^ = eval:self.notag +; # 何らかの原因で漢字表記できなくなった場合 +; # 0 to use broken form, 1 to use orthodox form +; # Example: '(K たち (F えー) ばな;橘)' +K = 1 +; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' +K^ = 1 +; # 転訛、発音の怠けなど、一時的な発音エラー +; # 0 to use wrong form, 1 to use orthodox form +; # Example: '(W ギーツ;ギジュツ)' +W = 1 +; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' +W^ = 1 +; # 語の読みに関する知識レベルのいい間違い +; # 0 to use wrong form, 1 to use orthodox form +; # Example: '(B シブタイ;ジュータイ)' +B = 0 +; # Example: 'データー(B カズ;スー)' +B^ = 0 +; # 笑いながら発話 +; # 0 to remain, 1 to delete +; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' +笑 = 0 +; # Example: 'コク(笑 サイ+(D オン))', +笑^ = 0 +; # 泣きながら発話 +; # 0 to remain, 1 to delete +; # Example: '(泣 ドンナニ)' +泣 = 0 +泣^ = 0 +; # 咳をしながら発話 +; # 0 to remain, 1 to delete +; # Example: 'シャ(咳 リン) ノ' +咳 = 0 +; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' +咳^ = 0 +; # ささやき声や独り言などの小さな声 +; # 0 to remain, 1 to delete +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +L = 0 +; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' +L^ = 0 + +[REPLACEMENTS] +; # ボーカルフライなどで母音が同定できない場合 + = +; # 「うん/うーん/ふーん」の音の特定が困難な場合 + = +; # 非語彙的な母音の引き延ばし + = +; # 非語彙的な子音の引き延ばし + = +; # 言語音と独立に講演者の笑いが生じている場合 +<笑> = +; # 言語音と独立に講演者の咳が生じている場合 +<咳> = +; # 言語音と独立に講演者の息が生じている場合 +<息> = +; # 講演者の泣き声 +<泣> = +; # 聴衆(司会者なども含む)の発話 +<フロア発話> = +; # 聴衆の笑い +<フロア笑> = +; # 聴衆の拍手 +<拍手> = +; # 講演者が発表中に用いたデモンストレーションの音声 +<デモ> = +; # 学会講演に発表時間を知らせるためにならすベルの音 +<ベル> = +; # 転記単位全体が再度読み直された場合 +<朗読間違い> = +; # 上記以外の音で特に目立った音 +<雑音> = +; # 0.2秒以上のポーズ +

= +; # Redacted information, for R +; # It is \x00D7 multiplication sign, not your normal 'x' +× = × + +[FIELDS] +; # Time information for segment +time = 3 +; # Word surface +surface = 5 +; # Word surface root form without CSJ tags +notag = 9 +; # Part Of Speech +pos1 = 11 +; # Conjugated Form +cForm = 12 +; # Conjugation Type +cType1 = 13 +; # Subcategory of POS +pos2 = 14 +; # Euphonic Change / Subcategory of Conjugation Type +cType2 = 15 +; # Other information +other = 16 +; # Pronunciation for lexicon +pron = 10 +; # Speaker ID +spk_id = 2 + +[KATAKANA2ROMAJI] +ア = 'a +イ = 'i +ウ = 'u +エ = 'e +オ = 'o +カ = ka +キ = ki +ク = ku +ケ = ke +コ = ko +ガ = ga +ギ = gi +グ = gu +ゲ = ge +ゴ = go +サ = sa +シ = si +ス = su +セ = se +ソ = so +ザ = za +ジ = zi +ズ = zu +ゼ = ze +ゾ = zo +タ = ta +チ = ti +ツ = tu +テ = te +ト = to +ダ = da +ヂ = di +ヅ = du +デ = de +ド = do +ナ = na +ニ = ni +ヌ = nu +ネ = ne +ノ = no +ハ = ha +ヒ = hi +フ = hu +ヘ = he +ホ = ho +バ = ba +ビ = bi +ブ = bu +ベ = be +ボ = bo +パ = pa +ピ = pi +プ = pu +ペ = pe +ポ = po +マ = ma +ミ = mi +ム = mu +メ = me +モ = mo +ヤ = ya +ユ = yu +ヨ = yo +ラ = ra +リ = ri +ル = ru +レ = re +ロ = ro +ワ = wa +ヰ = we +ヱ = wi +ヲ = wo +ン = ŋ +ッ = q +ー = - +キャ = kǐa +キュ = kǐu +キョ = kǐo +ギャ = gǐa +ギュ = gǐu +ギョ = gǐo +シャ = sǐa +シュ = sǐu +ショ = sǐo +ジャ = zǐa +ジュ = zǐu +ジョ = zǐo +チャ = tǐa +チュ = tǐu +チョ = tǐo +ヂャ = dǐa +ヂュ = dǐu +ヂョ = dǐo +ニャ = nǐa +ニュ = nǐu +ニョ = nǐo +ヒャ = hǐa +ヒュ = hǐu +ヒョ = hǐo +ビャ = bǐa +ビュ = bǐu +ビョ = bǐo +ピャ = pǐa +ピュ = pǐu +ピョ = pǐo +ミャ = mǐa +ミュ = mǐu +ミョ = mǐo +リャ = rǐa +リュ = rǐu +リョ = rǐo +ァ = a +ィ = i +ゥ = u +ェ = e +ォ = o +ヮ = ʍ +ヴ = vu +ャ = ǐa +ュ = ǐu +ョ = ǐo + diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py new file mode 100644 index 000000000..c9de21073 --- /dev/null +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2022 The University of Electro-Communications (author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path + +from lhotse import CutSet, load_manifest + +ARGPARSE_DESCRIPTION = """ +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in +pruned_transducer_stateless5/train.py for usage. 
+""" + + +def get_parser(): + parser = argparse.ArgumentParser( + description=ARGPARSE_DESCRIPTION, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--manifest-dir", type=Path, help="Path to cutset manifests" + ) + + return parser.parse_args() + + +def main(): + args = get_parser() + + for path in args.manifest_dir.glob("csj_cuts_*.jsonl.gz"): + + cuts: CutSet = load_manifest(path) + + print("\n---------------------------------\n") + print(path.name + ":") + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +## eval1 +Cuts count: 1272 +Total duration (hh:mm:ss): 01:50:07 +Speech duration (hh:mm:ss): 01:50:07 (100.0%) +Duration statistics (seconds): +mean 5.2 +std 3.9 +min 0.2 +25% 1.9 +50% 4.0 +75% 8.1 +99% 14.3 +99.5% 14.7 +99.9% 16.0 +max 16.9 +Recordings available: 1272 +Features available: 1272 +Supervisions available: 1272 +SUPERVISION custom fields: +- fluent (in 1272 cuts) +- disfluent (in 1272 cuts) +- number (in 1272 cuts) +- symbol (in 1272 cuts) + +## eval2 +Cuts count: 1292 +Total duration (hh:mm:ss): 01:56:50 +Speech duration (hh:mm:ss): 01:56:50 (100.0%) +Duration statistics (seconds): +mean 5.4 +std 3.9 +min 0.1 +25% 2.1 +50% 4.6 +75% 8.6 +99% 14.1 +99.5% 15.2 +99.9% 16.1 +max 16.9 +Recordings available: 1292 +Features available: 1292 +Supervisions available: 1292 +SUPERVISION custom fields: +- fluent (in 1292 cuts) +- number (in 1292 cuts) +- symbol (in 1292 cuts) +- disfluent (in 1292 cuts) + +## eval3 +Cuts count: 1385 +Total duration (hh:mm:ss): 01:19:21 +Speech duration (hh:mm:ss): 01:19:21 (100.0%) +Duration statistics (seconds): +mean 3.4 +std 3.0 +min 0.2 +25% 1.2 +50% 2.5 +75% 4.6 +99% 12.7 +99.5% 13.7 +99.9% 15.0 +max 15.9 +Recordings available: 1385 +Features available: 1385 +Supervisions available: 1385 +SUPERVISION custom fields: +- number (in 1385 cuts) +- symbol (in 1385 cuts) +- fluent (in 1385 cuts) +- disfluent (in 1385 cuts) + +## valid +Cuts count: 4000 +Total duration (hh:mm:ss): 05:08:09 +Speech duration (hh:mm:ss): 05:08:09 (100.0%) +Duration statistics (seconds): +mean 4.6 +std 3.8 +min 0.1 +25% 1.5 +50% 3.4 +75% 7.0 +99% 13.8 +99.5% 14.8 +99.9% 16.0 +max 17.3 +Recordings available: 4000 +Features available: 4000 +Supervisions available: 4000 +SUPERVISION custom fields: +- fluent (in 4000 cuts) +- symbol (in 4000 cuts) +- disfluent (in 4000 cuts) +- number (in 4000 cuts) + +## train +Cuts count: 1291134 +Total duration (hh:mm:ss): 1596:37:27 +Speech duration (hh:mm:ss): 1596:37:27 (100.0%) +Duration statistics (seconds): +mean 4.5 +std 3.6 +min 0.0 +25% 1.6 +50% 3.3 +75% 6.4 +99% 14.0 +99.5% 14.8 +99.9% 16.6 +max 27.8 +Recordings available: 1291134 +Features available: 1291134 +Supervisions available: 1291134 +SUPERVISION custom fields: +- disfluent (in 1291134 cuts) +- fluent (in 1291134 cuts) +- symbol (in 1291134 cuts) +- number (in 1291134 cuts) +""" diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py new file mode 100644 index 000000000..e4d996871 --- /dev/null +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet + +ARGPARSE_DESCRIPTION = """ +This script gathers all training transcripts of the specified {trans_mode} type +and produces a token_list that would be output set of the ASR system. + +It splits transcripts by whitespace into lists, then, for each word in the +list, if the word does not appear in the list of user-defined multicharacter +strings, it further splits that word into individual characters to be counted +into the output token set. + +It outputs 4 files into the lang directory: +- trans_mode: the name of transcript mode. If trans_mode was not specified, + this will be an empty file. +- userdef_string: a list of user defined strings that should not be split + further into individual characters. By default, it contains "", "", + "" +- words_len: the total number of tokens in the output set. +- words.txt: a list of tokens in the output set. The length matches words_len. + +""" + + +def get_args(): + parser = argparse.ArgumentParser( + description=ARGPARSE_DESCRIPTION, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--train-cut", type=Path, required=True, help="Path to the train cut" + ) + + parser.add_argument( + "--trans-mode", + type=str, + default=None, + help=( + "Name of the transcript mode to use. " + "If lang-dir is not set, this will also name the lang-dir" + ), + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=None, + help=( + "Name of lang dir. " + "If not set, this will default to lang_char_{trans-mode}" + ), + ) + + parser.add_argument( + "--userdef-string", + type=Path, + default=None, + help="Multicharacter strings that do not need to be split", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + logging.basicConfig( + format=( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" + ), + level=logging.INFO, + ) + + if not args.lang_dir: + p = "lang_char" + if args.trans_mode: + p += f"_{args.trans_mode}" + args.lang_dir = Path(p) + + if args.userdef_string: + args.userdef_string = set(args.userdef_string.read_text().split()) + else: + args.userdef_string = set() + + sysdef_string = ["", "", ""] + args.userdef_string.update(sysdef_string) + + train_set: CutSet = CutSet.from_file(args.train_cut) + + words = set() + logging.info( + f"Creating vocabulary from {args.train_cut.name}" + f" at {args.trans_mode} mode." 
+ ) + for cut in train_set: + try: + text: str = ( + cut.supervisions[0].custom[args.trans_mode] + if args.trans_mode + else cut.supervisions[0].text + ) + except KeyError: + raise KeyError( + f"Could not find {args.trans_mode} in " + f"{cut.supervisions[0].custom}" + ) + for t in text.split(): + if t in args.userdef_string: + words.add(t) + else: + words.update(c for c in list(t)) + + words -= set(sysdef_string) + words = sorted(words) + words = [""] + words + ["", ""] + + args.lang_dir.mkdir(parents=True, exist_ok=True) + (args.lang_dir / "words.txt").write_text( + "\n".join(f"{word}\t{i}" for i, word in enumerate(words)) + ) + + (args.lang_dir / "words_len").write_text(f"{len(words)}") + + (args.lang_dir / "userdef_string").write_text( + "\n".join(args.userdef_string) + ) + + (args.lang_dir / "trans_mode").write_text(args.trans_mode) + logging.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py new file mode 100644 index 000000000..0c4c6c1ea --- /dev/null +++ b/egs/csj/ASR/local/validate_manifest.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut +- Supervision time bounds are within cut time bounds + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def validate_one_supervision_per_cut(c: Cut): + if len(c.supervisions) != 1: + raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") + + +def validate_supervision_and_cut_time_bounds(c: Cut): + s = c.supervisions[0] + + # Removed because when the cuts were trimmed from supervisions, + # the start time of the supervision can be lesser than cut start time. 
+ # https://github.com/lhotse-speech/lhotse/issues/813 + # if s.start < c.start: + # raise ValueError( + # f"{c.id}: Supervision start time {s.start} is less " + # f"than cut start time {c.start}" + # ) + + if s.end > c.end: + raise ValueError( + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" + ) + + +def main(): + args = get_args() + + manifest = Path(args.manifest) + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest(manifest) + assert isinstance(cut_set, CutSet) + + for c in cut_set: + validate_one_supervision_per_cut(c) + validate_supervision_and_cut_time_bounds(c) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh new file mode 100755 index 000000000..269c1ec9a --- /dev/null +++ b/egs/csj/ASR/prepare.sh @@ -0,0 +1,130 @@ +#!/usr/bin/env bash +# We assume the following directories are downloaded. +# +# - $csj_dir +# CSJ is assumed to be the USB-type directory, which should contain the following subdirectories:- +# - DATA (not used in this script) +# - DOC (not used in this script) +# - MODEL (not used in this script) +# - MORPH +# - LDB (not used in this script) +# - SUWDIC (not used in this script) +# - SDB +# - core +# - ... +# - noncore +# - ... +# - PLABEL (not used in this script) +# - SUMMARY (not used in this script) +# - TOOL (not used in this script) +# - WAV +# - core +# - ... +# - noncore +# - ... +# - XML (not used in this script) +# +# - $musan_dir +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# - music +# - noise +# - speech +# +# By default, this script produces the original transcript like kaldi and espnet. Optionally, you +# can generate other transcript formats by supplying your own config files. A few examples of these +# config files can be found in local/conf. + +set -eou pipefail + +nj=8 +stage=-1 +stop_stage=100 + +csj_dir=/mnt/minami_data_server/t2131178/corpus/CSJ +musan_dir=/mnt/minami_data_server/t2131178/corpus/musan/musan +trans_dir=$csj_dir/retranscript +csj_fbank_dir=/mnt/host/csj_data/fbank +musan_fbank_dir=$musan_dir/fbank +csj_manifest_dir=data/manifests +musan_manifest_dir=$musan_dir/manifests + +. shared/parse_options.sh || exit 1 + +mkdir -p data + +log() { + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare CSJ manifest" + # If you want to generate more transcript modes, append the path to those config files at c. + # Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini + # NOTE: In case multiple config files are supplied, the second config file and onwards will inherit + # the segment boundaries of the first config file. + if [ ! -e $csj_manifest_dir/.librispeech.done ]; then + lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4 + touch $csj_manifest_dir/.librispeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + mkdir -p $musan_manifest_dir + if [ ! 
-e $musan_manifest_dir/.musan.done ]; then + lhotse prepare musan $musan_dir $musan_manifest_dir + touch $musan_manifest_dir/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute CSJ fbank" + if [ ! -e $csj_fbank_dir/.csj-validated.done ]; then + python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \ + --fbank-dir $csj_fbank_dir + parts=( + train + valid + eval1 + eval2 + eval3 + ) + for part in ${parts[@]}; do + python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz + done + touch $csj_fbank_dir/.csj-validated.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare CSJ lang" + modes=disfluent + + # If you want prepare the lang directory for other transcript modes, just append + # the names of those modes behind. An example is shown as below:- + # modes="$modes fluent symbol number" + + for mode in ${modes[@]}; do + python local/prepare_lang_char.py --trans-mode $mode \ + --train-cut $csj_manifest_dir/csj_cuts_train.jsonl.gz \ + --lang-dir lang_char_$mode + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute fbank for musan" + mkdir -p $musan_fbank_dir + + if [ ! -e $musan_fbank_dir/.musan.done ]; then + python local/compute_fbank_musan.py --manifest-dir $musan_manifest_dir --fbank-dir $musan_fbank_dir + touch $musan_fbank_dir/.musan.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Show manifest statistics" + python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt + cat $csj_manifest_dir/manifest_statistics.txt +fi \ No newline at end of file diff --git a/egs/csj/ASR/shared b/egs/csj/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/csj/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 36e0c7aea..6fac07f93 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -253,7 +253,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) + src = residual + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) if not self.normalize_before: src = self.norm_conv(src) @@ -890,11 +892,16 @@ class ConvolutionModule(nn.Module): ) self.activation = Swish() - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). 
@@ -908,6 +915,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) if self.use_batchnorm: x = self.norm(x) diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index f4b438aad..51406667e 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -173,13 +173,13 @@ def get_params() -> AttributeDict: def post_processing( - results: List[Tuple[List[str], List[str]]], -) -> List[Tuple[List[str], List[str]]]: + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: new_results = [] - for ref, hyp in results: + for key, ref, hyp in results: new_ref = asr_text_post_processing(" ".join(ref)).split() new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((new_ref, new_hyp)) + new_results.append((key, new_ref, new_hyp)) return new_results @@ -408,7 +408,7 @@ def decode_dataset( sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -502,7 +502,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): if params.method == "attention-decoder": # Set it to False since there are too many logs. diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index d7ecc3fdc..5849a3471 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -203,13 +203,13 @@ def get_parser(): def post_processing( - results: List[Tuple[List[str], List[str]]], -) -> List[Tuple[List[str], List[str]]]: + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: new_results = [] - for ref, hyp in results: + for key, ref, hyp in results: new_ref = asr_text_post_processing(" ".join(ref)).split() new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((new_ref, new_hyp)) + new_results.append((key, new_ref, new_hyp)) return new_results @@ -340,7 +340,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args:
@@ -407,7 +407,7 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 8a27b4b63..d5a67b619 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,12 +1,100 @@
## Results
-#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + multi-dataset)
+### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
-[lstm_transducer_stateless2](./lstm_transducer_stateless2)
+#### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
+
+It implements an LSTM model with mechanisms from the reworked model for streaming ASR.
+A gradient filter is applied inside each LSTM module to stabilize training.
+
+See for more details.
+
+##### training on full librispeech
+
+This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
+
+The WERs are:
+
+| | test-clean | test-other | comment | decoding mode |
+|-------------------------------------|------------|------------|----------------------|----------------------|
+| greedy search (max sym per frame 1) | 3.66 | 9.51 | --epoch 40 --avg 15 | simulated streaming |
+| greedy search (max sym per frame 1) | 3.66 | 9.48 | --epoch 40 --avg 15 | streaming |
+| fast beam search | 3.55 | 9.33 | --epoch 40 --avg 15 | simulated streaming |
+| fast beam search | 3.57 | 9.25 | --epoch 40 --avg 15 | streaming |
+| modified beam search | 3.55 | 9.28 | --epoch 40 --avg 15 | simulated streaming |
+| modified beam search | 3.54 | 9.25 | --epoch 40 --avg 15 | streaming |
+
+Note: `simulated streaming` means feeding the full utterance at once during decoding, while `streaming` means feeding a fixed number of frames at a time. 
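To make the two modes concrete, here is a rough sketch of the difference; a plain `nn.LSTM` stands in for the recipe's encoder, and all names, sizes, and the chunk length are illustrative rather than taken from `lstm_transducer_stateless3`:

```python
import torch
import torch.nn as nn

# Illustration only: a generic LSTM stands in for the recipe's encoder.
encoder = nn.LSTM(input_size=80, hidden_size=512, num_layers=2, batch_first=True)
feats = torch.randn(1, 1000, 80)  # (N, T, C) fbank-like features

# "simulated streaming": the whole utterance is fed in a single call.
full_out, _ = encoder(feats)

# "streaming": fixed-size chunks are fed one after another and the LSTM
# states are carried across calls, so each call only sees the new frames.
chunk_size = 32  # frames per call, illustrative
states = None
chunks = []
for start in range(0, feats.size(1), chunk_size):
    out, states = encoder(feats[:, start : start + chunk_size], states)
    chunks.append(out)
stream_out = torch.cat(chunks, dim=1)

# For a purely unidirectional LSTM the two are numerically equivalent; models
# with lookahead or chunked attention generally differ, hence the separate rows.
print(torch.allclose(full_out, stream_out, atol=1e-5))
```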
+ + +The training command is: + +```bash +./lstm_transducer_stateless3/train.py \ + --world-size 4 \ + --num-epochs 40 \ + --start-epoch 1 \ + --exp-dir lstm_transducer_stateless3/exp \ + --full-libri 1 \ + --max-duration 500 \ + --master-port 12325 \ + --num-encoder-layers 12 \ + --grad-norm-threshold 25.0 \ + --rnn-hidden-size 1024 +``` + +The tensorboard log can be found at + + +The simulated streaming decoding command using greedy search, fast beam search, and modified beam search is: +```bash +for decoding_method in greedy_search fast_beam_search modified_beam_search; do + ./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 15 \ + --exp-dir lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method $decoding_method \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --beam-size 4 +done +``` + +The streaming decoding command using greedy search, fast beam search, and modified beam search is: +```bash +for decoding_method in greedy_search fast_beam_search modified_beam_search; do + ./lstm_transducer_stateless3/streaming_decode.py \ + --epoch 40 \ + --avg 15 \ + --exp-dir lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method $decoding_method \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --beam-size 4 +done +``` + +Pretrained models, training logs, decoding logs, and decoding results +are available at + + + +### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + multi-dataset) + +#### [lstm_transducer_stateless2](./lstm_transducer_stateless2) See for more details. - The WERs are: | | test-clean | test-other | comment | @@ -18,6 +106,7 @@ The WERs are: | modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 | | fast_beam_search | 2.77 | 7.29 | --iter 472000 --avg 18 | + The training command is: ```bash @@ -70,15 +159,16 @@ Pretrained models, training logs, decoding logs, and decoding results are available at -#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T) -[lstm_transducer_stateless](./lstm_transducer_stateless) +### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T) + +#### [lstm_transducer_stateless](./lstm_transducer_stateless) It implements LSTM model with mechanisms in reworked model for streaming ASR. See for more details. -#### training on full librispeech +##### training on full librispeech This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496. @@ -165,7 +255,7 @@ It is modified from [torchaudio](https://github.com/pytorch/audio). See for more details. -#### With lower latency setup, training on full librispeech +##### With lower latency setup, training on full librispeech In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively. @@ -316,7 +406,7 @@ Pretrained models, training logs, decoding logs, and decoding results are available at -#### With higher latency setup, training on full librispeech +##### With higher latency setup, training on full librispeech In this model, the lengths of chunk and right context are 64 frames (i.e., 0.64s) and 16 frames (i.e., 0.16s), respectively. 
@@ -851,14 +941,14 @@ Pre-trained models, training and decoding logs, and decoding results are availab ### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T) -[conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless) +#### [conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless) It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with convolution module for streaming ASR. It is modified from [torchaudio](https://github.com/pytorch/audio). See for more details. -#### Training on full librispeech +##### Training on full librispeech In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively. @@ -1011,7 +1101,7 @@ are available at ### LibriSpeech BPE training results (Pruned Stateless Emformer RNN-T) -[pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2) +#### [pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2) Use . @@ -1079,7 +1169,7 @@ results at: ### LibriSpeech BPE training results (Pruned Stateless Transducer 5) -[pruned_transducer_stateless5](./pruned_transducer_stateless5) +#### [pruned_transducer_stateless5](./pruned_transducer_stateless5) Same as `Pruned Stateless Transducer 2` but with more layers. @@ -1092,7 +1182,7 @@ The notations `large` and `medium` below are from the [Conformer](https://arxiv. paper, where the large model has about 118 M parameters and the medium model has 30.8 M parameters. -#### Large +##### Large Number of model parameters 118129516 (i.e, 118.13 M). @@ -1152,7 +1242,7 @@ results at: -#### Medium +##### Medium Number of model parameters 30896748 (i.e, 30.9 M). @@ -1212,7 +1302,7 @@ results at: -#### Baseline-2 +##### Baseline-2 It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, its has more layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder dim vs 2048 feed forward dim and 512 encoder dim). @@ -1273,13 +1363,13 @@ results at: ### LibriSpeech BPE training results (Pruned Stateless Transducer 4) -[pruned_transducer_stateless4](./pruned_transducer_stateless4) +#### [pruned_transducer_stateless4](./pruned_transducer_stateless4) This version saves averaged model during training, and decodes with averaged model. See for details about the idea of model averaging. -#### Training on full librispeech +##### Training on full librispeech See @@ -1355,7 +1445,7 @@ Pretrained models, training logs, decoding logs, and decoding results are available at -#### Training on train-clean-100 +##### Training on train-clean-100 See @@ -1392,7 +1482,7 @@ The tensorboard log can be found at ### LibriSpeech BPE training results (Pruned Stateless Transducer 3, 2022-04-29) -[pruned_transducer_stateless3](./pruned_transducer_stateless3) +#### [pruned_transducer_stateless3](./pruned_transducer_stateless3) Same as `Pruned Stateless Transducer 2` but using the XL subset from [GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data. @@ -1606,10 +1696,10 @@ can be found at ### LibriSpeech BPE training results (Pruned Transducer 2) -[pruned_transducer_stateless2](./pruned_transducer_stateless2) +#### [pruned_transducer_stateless2](./pruned_transducer_stateless2) This is with a reworked version of the conformer encoder, with many changes. -#### Training on fulll librispeech +##### Training on full librispeech Using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`. 
See @@ -1658,7 +1748,7 @@ can be found at -#### Training on train-clean-100: +##### Training on train-clean-100: Trained with 1 job: ``` diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 36e0c7aea..6fac07f93 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -253,7 +253,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) + src = residual + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) if not self.normalize_before: src = self.norm_conv(src) @@ -890,11 +892,16 @@ class ConvolutionModule(nn.Module): ) self.activation = Swish() - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). @@ -908,6 +915,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) if self.use_batchnorm: x = self.norm(x) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 7a8fb2130..3f3b1acda 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -480,7 +480,7 @@ def decode_dataset( sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -577,7 +577,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], ): if params.method in ("attention-decoder", "rnn-lm"): # Set it to False since there are too many logs. diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index fb11a5fc8..b906d2650 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -268,7 +268,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - src = src + self.dropout(self.conv_module(src)) + src = src + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -921,11 +923,16 @@ class ConvolutionModule(nn.Module): initial_scale=0.25, ) - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). 
@@ -941,6 +948,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) x = self.deriv_balancer2(x) diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 6b9da12a9..97f2f2d39 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -587,7 +587,7 @@ def decode_dataset( sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -684,7 +684,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): if params.method in ("attention-decoder", "rnn-lm"): # Set it to False since there are too many logs. diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index b5f22825d..97c8d83a2 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -247,7 +247,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) + src = residual + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) if not self.normalize_before: src = self.norm_conv(src) @@ -878,11 +880,16 @@ class ConvolutionModule(nn.Module): ) self.activation = Swish() - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). @@ -896,6 +903,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) x = self.activation(self.norm(x)) diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index 23372034a..fc9861489 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -404,7 +404,7 @@ def decode_dataset( sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -487,7 +487,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): if params.method == "attention-decoder": # Set it to False since there are too many logs. 
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index a03fe2684..620d69a19 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -366,7 +366,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -436,7 +436,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index d204a9d75..98b8290b5 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -366,7 +366,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -436,7 +436,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index bfc158e0a..27414d717 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -496,7 +496,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -570,7 +570,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index 0d268ab07..c54a4c478 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -116,6 +116,8 @@ class RNN(EncoderInterface): Period of auxiliary layers used for random combiner during training. If set to 0, will not use the random combiner (Default). You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. 
""" def __init__( @@ -129,6 +131,7 @@ class RNN(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, aux_layer_period: int = 0, + is_pnnx: bool = False, ) -> None: super(RNN, self).__init__() @@ -142,7 +145,13 @@ class RNN(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx self.num_encoder_layers = num_encoder_layers self.d_model = d_model @@ -209,7 +218,13 @@ class RNN(EncoderInterface): # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning # # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 3) >> 1) - 1) >> 1 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() @@ -359,7 +374,7 @@ class RNNEncoderLayer(nn.Module): # for cell state assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) src_lstm, new_states = self.lstm(src, states) - src = src + self.dropout(src_lstm) + src = self.dropout(src_lstm) + src # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -505,6 +520,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, + is_pnnx: bool = False, ) -> None: """ Args: @@ -517,6 +533,9 @@ class Conv2dSubsampling(nn.Module): Number of channels in layer1 layer1_channels: Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. """ assert in_channels >= 9 super().__init__() @@ -559,6 +578,10 @@ class Conv2dSubsampling(nn.Module): channel_dim=-1, min_positive=0.45, max_positive=0.55 ) + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. 
@@ -572,9 +595,15 @@ class Conv2dSubsampling(nn.Module): # On entry, x is (N, T, idim) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-3)//2-1))//2, odim) x = self.out_norm(x) x = self.out_balancer(x) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 21ae563cb..420202cad 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -497,7 +497,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -571,7 +571,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index a1ed6b3b1..190673638 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -74,6 +74,29 @@ with the following commands: git lfs install git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 # You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp + +(3) Export to ONNX format + +./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Please see ./streaming-onnx-decode.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. """ import argparse @@ -169,6 +192,35 @@ def get_parser(): """, ) + parser.add_argument( + "--pnnx", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace for later + converting to PNNX. 
It will generate 3 files: + - encoder_jit_trace-pnnx.pt + - decoder_jit_trace-pnnx.pt + - joiner_jit_trace-pnnx.pt + """, + ) + + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit and --pnnx are ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + parser.add_argument( "--context-size", type=int, @@ -254,6 +306,215 @@ def export_joiner_model_jit_trace( logging.info(f"Saved to {joiner_filename}") +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has 3 inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + - states: a tuple containing: + - h0: a tensor of shape (num_layers, N, proj_size) + - c0: a tensor of shape (num_layers, N, hidden_size) + + and it has 3 outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + - states: a tuple containing: + - next_h0: a tensor of shape (num_layers, N, proj_size) + - next_c0: a tensor of shape (num_layers, N, hidden_size) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + N = 1 + x = torch.zeros(N, 9, 80, dtype=torch.float32) + x_lens = torch.tensor([9], dtype=torch.int64) + h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) + c = torch.rand( + encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size + ) + + warmup = 1.0 + torch.onnx.export( + encoder_model, # use torch.jit.trace() internally + (x, x_lens, (h, c), warmup), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "h", "c", "warmup"], + output_names=["encoder_out", "encoder_out_lens", "next_h", "next_c"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "h": {1: "N"}, + "c": {1: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + "next_h": {1: "N"}, + "next_c": {1: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "projected_encoder_out", + "projected_decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "projected_encoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -277,6 +538,10 @@ def main(): logging.info(params) + if params.pnnx: + params.is_pnnx = params.pnnx + logging.info("For PNNX") + logging.info("About to create model") model = get_transducer_model(params, enable_giga=False) 
@@ -371,7 +636,44 @@ def main(): model.to("cpu") model.eval() - if params.jit_trace is True: + if params.onnx: + logging.info("Export model to ONNX format") + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + opset_version = 11 + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + elif params.pnnx: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + elif params.jit_trace is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.trace()") encoder_filename = params.exp_dir / "encoder_jit_trace.pt" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py new file mode 100644 index 000000000..dba6eb520 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py @@ -0,0 +1,102 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LSTMP(nn.Module): + """LSTM with projection. + + PyTorch does not support exporting LSTM with projection to ONNX. + This class reimplements LSTM with projection using basic matrix-matrix + and matrix-vector operations. It is not intended for training. + """ + + def __init__(self, lstm: nn.LSTM): + """ + Args: + lstm: + LSTM with proj_size. We support only uni-directional, + 1-layer LSTM with projection at present. + """ + super().__init__() + assert lstm.bidirectional is False, lstm.bidirectional + assert lstm.num_layers == 1, lstm.num_layers + assert 0 < lstm.proj_size < lstm.hidden_size, ( + lstm.proj_size, + lstm.hidden_size, + ) + + assert lstm.batch_first is False, lstm.batch_first + + state_dict = lstm.state_dict() + + w_ih = state_dict["weight_ih_l0"] + w_hh = state_dict["weight_hh_l0"] + + b_ih = state_dict["bias_ih_l0"] + b_hh = state_dict["bias_hh_l0"] + + w_hr = state_dict["weight_hr_l0"] + self.input_size = lstm.input_size + self.proj_size = lstm.proj_size + self.hidden_size = lstm.hidden_size + + self.w_ih = w_ih + self.w_hh = w_hh + self.b = b_ih + b_hh + self.w_hr = w_hr + + def forward( + self, + input: torch.Tensor, + hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + input: + A tensor of shape [T, N, hidden_size] + hx: + A tuple containing: + - h0: a tensor of shape (1, N, proj_size) + - c0: a tensor of shape (1, N, hidden_size) + Returns: + Return a tuple containing: + - output: a tensor of shape (T, N, proj_size). 
+ - A tuple containing: + - h: a tensor of shape (1, N, proj_size) + - c: a tensor of shape (1, N, hidden_size) + + """ + x_list = input.unbind(dim=0) # We use batch_first=False + + if hx is not None: + h0, c0 = hx + else: + h0 = torch.zeros(1, input.size(1), self.proj_size) + c0 = torch.zeros(1, input.size(1), self.hidden_size) + h0 = h0.squeeze(0) + c0 = c0.squeeze(0) + y_list = [] + for x in x_list: + gates = F.linear(x, self.w_ih, self.b) + F.linear(h0, self.w_hh) + i, f, g, o = gates.chunk(4, dim=1) + + i = i.sigmoid() + f = f.sigmoid() + g = g.tanh() + o = o.sigmoid() + + c = f * c0 + i * g + h = o * c.tanh() + + h = F.linear(h, self.w_hr) + y_list.append(h) + + c0 = c + h0 = h + + y = torch.stack(y_list, dim=0) + + return y, (h0.unsqueeze(0), c0.unsqueeze(0)) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py new file mode 100755 index 000000000..410de8d3d --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# flake8: noqa +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
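The LSTMP class above exists because torch.onnx.export cannot handle nn.LSTM with proj_size, so before export each projected LSTM in the encoder has to be replaced by an equivalent LSTMP (the ONNX branch of export.py calls convert_scaled_to_non_scaled(..., is_onnx=True), which is presumably where a replacement along these lines happens). The helper below is only an illustrative sketch of such a swap, not the actual icefall code:

import torch.nn as nn
from lstmp import LSTMP


def replace_projected_lstms(module: nn.Module) -> None:
    # Recursively replace every nn.LSTM that uses proj_size with an LSTMP.
    # LSTMP itself asserts num_layers == 1, batch_first == False and a
    # unidirectional LSTM, which matches the encoder layers in this recipe.
    for name, child in module.named_children():
        if isinstance(child, nn.LSTM) and child.proj_size > 0:
            setattr(module, name, LSTMP(child))
        else:
            replace_projected_lstms(child)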
+""" +Usage: + ./lstm_transducer_stateless2/ncnn-decode.py \ + --bpe-model-filename ./data/lang_bpe_500/bpe.model \ + --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ + --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ + --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ + --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ + --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \ + --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \ + ./test_wavs/1089-134686-0001.wav +""" + +import argparse +import logging +from typing import List + +import kaldifeat +import ncnn +import sentencepiece as spm +import torch +import torchaudio + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model-filename", + type=str, + help="Path to bpe.model", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + decoder_net.opt.use_packing_layout = False + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.opt.use_packing_layout = False + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + def run_encoder(self, x, states): + with self.encoder_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + x_lens = torch.tensor([x.size(0)], dtype=torch.float32) + ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) + ex.input("in2", ncnn.Mat(states[0].numpy()).clone()) + ex.input("in3", ncnn.Mat(states[1].numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + + ret, ncnn_out1 = ex.extract("out1") + assert ret == 0, ret + + ret, ncnn_out2 = ex.extract("out2") + assert ret == 0, ret + + 
ret, ncnn_out3 = ex.extract("out3") + assert ret == 0, ret + + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) + hx = torch.from_numpy(ncnn_out2.numpy()).clone() + cx = torch.from_numpy(ncnn_out3.numpy()).clone() + return encoder_out, encoder_out_lens, hx, cx + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search(model: Model, encoder_out: torch.Tensor): + assert encoder_out.ndim == 2 + T = encoder_out.size(0) + + context_size = 2 + blank_id = 0 # hard-code to 0 + hyp = [blank_id] * context_size + + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + + decoder_out = model.run_decoder(decoder_input).squeeze(0) + # print(decoder_out.shape) # (512,) + + for t in range(T): + encoder_out_t = encoder_out[t] + joiner_out = model.run_joiner(encoder_out_t, decoder_out) + # print(joiner_out.shape) # [500] + y = joiner_out.argmax(dim=0).tolist() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + return hyp[context_size:] + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model_filename) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + + logging.info("Decoding started") + features = fbank(wave_samples) + + num_encoder_layers = 12 + d_model = 512 + rnn_hidden_size = 1024 + + states = ( + torch.zeros(num_encoder_layers, d_model), + torch.zeros( + num_encoder_layers, + rnn_hidden_size, + ), + ) + + encoder_out, encoder_out_lens, hx, cx = model.run_encoder(features, states) + hyp = 
greedy_search(model, encoder_out) + logging.info(sound_file) + logging.info(sp.decode(hyp)) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py new file mode 100755 index 000000000..e47a05a9e --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# flake8: noqa +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from typing import List, Optional + +import ncnn +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model-filename", + type=str, + help="Path to bpe.model", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + decoder_net.opt.use_packing_layout = False + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.opt.use_packing_layout = False + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + 
def run_encoder(self, x, states): + with self.encoder_net.create_extractor() as ex: + # ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + x_lens = torch.tensor([x.size(0)], dtype=torch.float32) + ex.input("in1", ncnn.Mat(x_lens.numpy()).clone()) + ex.input("in2", ncnn.Mat(states[0].numpy()).clone()) + ex.input("in3", ncnn.Mat(states[1].numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + + ret, ncnn_out1 = ex.extract("out1") + assert ret == 0, ret + + ret, ncnn_out2 = ex.extract("out2") + assert ret == 0, ret + + ret, ncnn_out3 = ex.extract("out3") + assert ret == 0, ret + + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) + hx = torch.from_numpy(ncnn_out2.numpy()).clone() + cx = torch.from_numpy(ncnn_out3.numpy()).clone() + return encoder_out, encoder_out_lens, hx, cx + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + # ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + # ex.set_num_threads(10) + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 1 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor( + hyp, dtype=torch.int32 + ) # (1, context_size) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + else: + assert decoder_out.ndim == 1 + assert hyp is not None, hyp + + joiner_out = model.run_joiner(encoder_out, decoder_out) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model_filename) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + num_encoder_layers = 12 + batch_size = 1 + d_model = 512 + rnn_hidden_size = 1024 + + states = ( + torch.zeros(num_encoder_layers, batch_size, d_model), + torch.zeros( + num_encoder_layers, + batch_size, + rnn_hidden_size, + ), + ) + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = 9 + offset = 4 + + chunk = 3200 # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) + states = (hx, cx) + hyp, decoder_out = greedy_search( + model, encoder_out.squeeze(0), decoder_out, hyp + ) + online_fbank.accept_waveform( + sampling_rate=sample_rate, waveform=torch.zeros(8000, dtype=torch.int32) + ) + + online_fbank.input_finished() + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) + states = (hx, cx) + hyp, decoder_out = greedy_search( + model, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + + logging.info(sound_file) + logging.info(sp.decode(hyp[context_size:])) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git 
a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py new file mode 100755 index 000000000..1c9ec3e89 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -0,0 +1,478 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +Usage of this script: + +./lstm_transducer_stateless2/onnx-streaming-decode.py \ + --encoder-model-filename ./lstm_transducer_stateless2/exp/encoder.onnx \ + --decoder-model-filename ./lstm_transducer_stateless2/exp/decoder.onnx \ + --joiner-model-filename ./lstm_transducer_stateless2/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./lstm_transducer_stateless2/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./lstm_transducer_stateless2/exp/joiner_decoder_proj.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +from typing import List, Optional, Tuple + +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model-filename", + type=str, + help="Path to bpe.model", + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser.parse_args() + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +class Model: + def __init__(self, args): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 5 + session_opts.intra_op_num_threads = 5 + self.session_opts = session_opts + + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + self.init_joiner_encoder_proj(args) + self.init_joiner_decoder_proj(args) + + def init_encoder(self, args): + self.encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, args): + self.decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=self.session_opts, + ) + + def init_joiner(self, args): + self.joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=self.session_opts, + ) + + def init_joiner_encoder_proj(self, args): + self.joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=self.session_opts, + ) + + def init_joiner_decoder_proj(self, args): + self.joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=self.session_opts, + ) + + def run_encoder( + self, x, h0, c0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (N, T, C) + h0: + A tensor of shape (num_layers, N, proj_size) + c0: + A tensor of shape (num_layers, N, hidden_size) + Returns: + Return a tuple containing: + - encoder_out: A tensor of shape (N, T', C') + - next_h0: A tensor of shape (num_layers, N, proj_size) + - next_c0: A tensor of shape (num_layers, N, hidden_size) + """ + encoder_input_nodes = self.encoder.get_inputs() + encoder_out_nodes = self.encoder.get_outputs() + x_lens = torch.tensor([x.size(1)], dtype=torch.int64) + + encoder_out, encoder_out_lens, next_h0, next_c0 = self.encoder.run( + [ + encoder_out_nodes[0].name, + encoder_out_nodes[1].name, + encoder_out_nodes[2].name, + encoder_out_nodes[3].name, + ], + { + encoder_input_nodes[0].name: x.numpy(), + encoder_input_nodes[1].name: x_lens.numpy(), + encoder_input_nodes[2].name: h0.numpy(), + encoder_input_nodes[3].name: c0.numpy(), + }, + ) + return ( + torch.from_numpy(encoder_out), + torch.from_numpy(next_h0), + torch.from_numpy(next_c0), + ) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A tensor of shape (N, context_size). Its dtype is torch.int64. + Returns: + Return a tensor of shape (N, 1, decoder_out_dim). 
+ """ + decoder_input_nodes = self.decoder.get_inputs() + decoder_output_nodes = self.decoder.get_outputs() + + decoder_out = self.decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0] + + return self.run_joiner_decoder_proj( + torch.from_numpy(decoder_out).squeeze(1) + ) + + def run_joiner( + self, + projected_encoder_out: torch.Tensor, + projected_decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + projected_encoder_out: + A tensor of shape (N, joiner_dim) + projected_decoder_out: + A tensor of shape (N, joiner_dim) + Returns: + Return a tensor of shape (N, vocab_size) + """ + joiner_input_nodes = self.joiner.get_inputs() + joiner_output_nodes = self.joiner.get_outputs() + + logits = self.joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: projected_encoder_out.numpy(), + joiner_input_nodes[1].name: projected_decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(logits) + + def run_joiner_encoder_proj( + self, + encoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A tensor of shape (N, encoder_out_dim) + Returns: + A tensor of shape (N, joiner_dim) + """ + + projected_encoder_out = self.joiner_encoder_proj.run( + [self.joiner_encoder_proj.get_outputs()[0].name], + { + self.joiner_encoder_proj.get_inputs()[ + 0 + ].name: encoder_out.numpy() + }, + )[0] + + return torch.from_numpy(projected_encoder_out) + + def run_joiner_decoder_proj( + self, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + decoder_out: + A tensor of shape (N, decoder_out_dim) + Returns: + A tensor of shape (N, joiner_dim) + """ + + projected_decoder_out = self.joiner_decoder_proj.run( + [self.joiner_decoder_proj.get_outputs()[0].name], + { + self.joiner_decoder_proj.get_inputs()[ + 0 + ].name: decoder_out.numpy() + }, + )[0] + + return torch.from_numpy(projected_decoder_out) + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + assert encoder_out.shape[0] == 1, "TODO: support batch_size > 1" + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor( + [hyp], dtype=torch.int64 + ) # (1, context_size) + decoder_out = model.run_decoder(decoder_input) + else: + assert decoder_out.shape[0] == 1 + assert hyp is not None, hyp + + projected_encoder_out = model.run_joiner_encoder_proj(encoder_out) + + joiner_out = model.run_joiner(projected_encoder_out, decoder_out) + y = joiner_out.squeeze(0).argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sound_file = args.sound_filename + sample_rate = 16000 + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model_filename) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + num_encoder_layers = 12 + batch_size = 1 + d_model = 512 + rnn_hidden_size = 1024 + + h0 = torch.zeros(num_encoder_layers, batch_size, d_model) + c0 = torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size) + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = 9 + offset = 4 + + chunk = 3200 # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + + num_processed_frames += offset + frames = torch.cat(frames, dim=0).unsqueeze(0) + encoder_out, h0, c0 = model.run_encoder(frames, h0, c0) + hyp, decoder_out = greedy_search( + model, encoder_out.squeeze(0), decoder_out, hyp + ) + online_fbank.accept_waveform( + sampling_rate=sample_rate, waveform=torch.zeros(5000, dtype=torch.float) + ) + + online_fbank.input_finished() + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0).unsqueeze(0) + encoder_out, h0, c0 = model.run_encoder(frames, h0, c0) + hyp, decoder_out = greedy_search( + model, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + + logging.info(sound_file) + logging.info(sp.decode(hyp[context_size:])) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() 
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py b/egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py new file mode 100755 index 000000000..00ba224cd --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +from lstmp import LSTMP + + +def test(): + input_size = torch.randint(low=10, high=1024, size=(1,)).item() + hidden_size = torch.randint(low=10, high=1024, size=(1,)).item() + proj_size = hidden_size - 1 + lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=True, + proj_size=proj_size, + ) + lstmp = LSTMP(lstm) + + N = torch.randint(low=1, high=10, size=(1,)).item() + T = torch.randint(low=1, high=20, size=(1,)).item() + x = torch.rand(T, N, input_size) + h0 = torch.rand(1, N, proj_size) + c0 = torch.rand(1, N, hidden_size) + + y1, (h1, c1) = lstm(x, (h0, c0)) + y2, (h2, c2) = lstmp(x, (h0, c0)) + + assert torch.allclose(y1, y2, atol=1e-5), (y1 - y2).abs().max() + assert torch.allclose(h1, h2, atol=1e-5), (h1 - h2).abs().max() + assert torch.allclose(c1, c2, atol=1e-5), (c1 - c2).abs().max() + + # lstm_script = torch.jit.script(lstm) # pytorch does not support it + lstm_script = lstm + lstmp_script = torch.jit.script(lstmp) + + y3, (h3, c3) = lstm_script(x, (h0, c0)) + y4, (h4, c4) = lstmp_script(x, (h0, c0)) + + assert torch.allclose(y3, y4, atol=1e-5), (y3 - y4).abs().max() + assert torch.allclose(h3, h4, atol=1e-5), (h3 - h4).abs().max() + assert torch.allclose(c3, c4, atol=1e-5), (c3 - c4).abs().max() + + assert torch.allclose(y3, y1, atol=1e-5), (y3 - y1).abs().max() + assert torch.allclose(h3, h1, atol=1e-5), (h3 - h1).abs().max() + assert torch.allclose(c3, c1, atol=1e-5), (c3 - c1).abs().max() + + lstm_trace = torch.jit.trace(lstm, (x, (h0, c0))) + lstmp_trace = torch.jit.trace(lstmp, (x, (h0, c0))) + + y5, (h5, c5) = lstm_trace(x, (h0, c0)) + y6, (h6, c6) = lstmp_trace(x, (h0, c0)) + + assert torch.allclose(y5, y6, atol=1e-5), (y5 - y6).abs().max() + assert torch.allclose(h5, h6, atol=1e-5), (h5 - h6).abs().max() + assert torch.allclose(c5, c6, atol=1e-5), (c5 - c6).abs().max() + + assert torch.allclose(y5, y1, atol=1e-5), (y5 - y1).abs().max() + assert torch.allclose(h5, h1, atol=1e-5), (h5 - h1).abs().max() + assert torch.allclose(c5, c1, atol=1e-5), (c5 - c1).abs().max() + + +@torch.no_grad() +def main(): + test() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index b3e50c52b..9eed2dfcb 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -406,6 +406,8 @@ def get_params() -> AttributeDict: "decoder_dim": 512, # parameters for joiner "joiner_dim": 512, + # True to generate a model that can be exported via PNNX + "is_pnnx": False, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), @@ -424,6 +426,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, aux_layer_period=params.aux_layer_period, + is_pnnx=params.is_pnnx, ) return encoder diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/__init__.py b/egs/librispeech/ASR/lstm_transducer_stateless3/__init__.py new file mode 120000 index 000000000..b24e5e357 --- /dev/null +++ 
b/egs/librispeech/ASR/lstm_transducer_stateless3/__init__.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/__init__.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/lstm_transducer_stateless3/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/beam_search.py b/egs/librispeech/ASR/lstm_transducer_stateless3/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py new file mode 100755 index 000000000..5be23c50c --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -0,0 +1,818 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./lstm_transducer_stateless2/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./lstm_transducer_stateless3/decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + feature = torch.nn.functional.pad( + feature, + (0, 0, 0, num_tail_padded_frames), + mode="constant", + value=LOG_EPS, + ) + feature_lens += num_tail_padded_frames + + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + 
encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
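The result-dictionary keys built at the end of `decode_one_batch()` encode the decoding settings. The hedged sketch below mirrors that construction as a standalone helper; `results_key` and the `SimpleNamespace` stand-in for `params` are hypothetical names, not part of the recipe:

```python
from types import SimpleNamespace


def results_key(p) -> str:
    """Mirror the key naming used in decode_one_batch() above."""
    if p.decoding_method == "greedy_search":
        return "greedy_search"
    elif "fast_beam_search" in p.decoding_method:
        key = f"beam_{p.beam}_"
        key += f"max_contexts_{p.max_contexts}_"
        key += f"max_states_{p.max_states}"
        if "nbest" in p.decoding_method:
            key += f"_num_paths_{p.num_paths}_"
            key += f"nbest_scale_{p.nbest_scale}"
            if "LG" in p.decoding_method:
                key += f"_ngram_lm_scale_{p.ngram_lm_scale}"
        return key
    else:
        return f"beam_size_{p.beam_size}"


p = SimpleNamespace(
    decoding_method="fast_beam_search_nbest_LG",
    beam=20.0, max_contexts=8, max_states=64,
    num_paths=200, nbest_scale=0.5, ngram_lm_scale=0.01, beam_size=4,
)
print(results_key(p))
# beam_20.0_max_contexts_8_max_states_64_num_paths_200_nbest_scale_0.5_ngram_lm_scale_0.01
```

These keys later become part of the `recogs-*` and `errs-*` file names written by `save_results()`, so keeping the naming stable matters when comparing runs.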
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + 
else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = 
k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decoder.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decoder.py new file mode 120000 index 000000000..0793c5709 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/encoder_interface.py b/egs/librispeech/ASR/lstm_transducer_stateless3/encoder_interface.py new file mode 120000 index 000000000..aa5d0217a --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py new file mode 100755 index 000000000..212c7bad6 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.trace() + +./lstm_transducer_stateless3/export.py \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 40 \ + --avg 20 \ + --jit-trace 1 + +It will generate 3 files: `encoder_jit_trace.pt`, +`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. + +(2) Export `model.state_dict()` + +./lstm_transducer_stateless3/export.py \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 40 \ + --avg 20 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. 
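Since `export.py` (without `--jit-trace`) saves nothing more than `{"model": model.state_dict()}`, the file can also be inspected or loaded without icefall's helper. A minimal sketch, assuming the default experiment directory:

```python
import torch

# Path matches the usage example above; adjust to your own exp dir.
ckpt = torch.load(
    "lstm_transducer_stateless3/exp/pretrained.pt", map_location="cpu"
)
print(list(ckpt.keys()))  # expected: ['model']

# With a model constructed via get_transducer_model(params) from train.py,
# the weights can be restored directly:
# model.load_state_dict(ckpt["model"])
```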
+ +To use the generated file with `lstm_transducer_stateless3/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./lstm_transducer_stateless3/decode.py \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + # You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + It will generate 3 files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, 
params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit_trace is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + else: + logging.info("Not using torchscript") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py new file mode 100755 index 000000000..a3443cf0a --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
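Before running `jit_pretrained.py` below on real audio, it can be handy to smoke-test the three traced files with the same dummy shapes that `export.py` used for tracing. This is only a sketch: the paths assume the default `exp` directory, and the shape comments assume the default configuration (80-dim fbank, 512-dim encoder, context size 2, 500-token BPE vocab).

```python
import torch

exp_dir = "lstm_transducer_stateless3/exp"  # assumed location of the traced files
encoder = torch.jit.load(f"{exp_dir}/encoder_jit_trace.pt", map_location="cpu")
decoder = torch.jit.load(f"{exp_dir}/decoder_jit_trace.pt", map_location="cpu")
joiner = torch.jit.load(f"{exp_dir}/joiner_jit_trace.pt", map_location="cpu")

# Same dummy input that export_encoder_model_jit_trace() traced with.
x = torch.zeros(1, 100, 80)
x_lens = torch.tensor([100], dtype=torch.int64)
states = encoder.get_init_states(batch_size=1, device=torch.device("cpu"))

encoder_out, encoder_out_lens, _ = encoder(x=x, x_lens=x_lens, states=states)
print(encoder_out.shape, encoder_out_lens)  # e.g. torch.Size([1, 23, 512]), tensor([23])

# Run one frame through the decoder and joiner, starting from blank context.
y = torch.zeros(1, 2, dtype=torch.int64)  # (batch, context_size=2)
decoder_out = decoder(y, need_pad=torch.tensor([False])).squeeze(1)
logits = joiner(encoder_out[:, 0], decoder_out)
print(logits.shape)  # (1, vocab_size), e.g. (1, 500)
```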
+""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. +You can use the following command to get the exported models: + +./lstm_transducer_stateless3/export.py \ + --exp-dir ./lstm_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 40 \ + --avg 15 \ + --jit-trace 1 + +Usage of this script: + +./lstm_transducer_stateless3/jit_pretrained.py \ + --encoder-model-filename ./lstm_transducer_stateless3/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless3/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless3/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. 
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + states = encoder.get_init_states(batch_size=features.size(0), device=device) + + encoder_out, encoder_out_lens, _ = encoder( + x=features, + x_lens=feature_lengths, + states=states, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + 
encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/joiner.py b/egs/librispeech/ASR/lstm_transducer_stateless3/joiner.py new file mode 120000 index 000000000..815fd4bb6 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py new file mode 100644 index 000000000..90bc351f4 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -0,0 +1,860 @@ +# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import List, Optional, Tuple + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv2d, + ScaledLinear, + ScaledLSTM, +) +from torch import nn + +LOG_EPSILON = math.log(1e-10) + + +def unstack_states( + states: Tuple[torch.Tensor, torch.Tensor] +) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """ + Unstack the lstm states corresponding to a batch of utterances into a list + of states, where the i-th entry is the state from the i-th utterance. + + Args: + states: + A tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. + + Returns: + A list of states. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + """ + hidden_states, cell_states = states + + list_hidden_states = hidden_states.unbind(dim=1) + list_cell_states = cell_states.unbind(dim=1) + + ans = [ + (h.unsqueeze(1), c.unsqueeze(1)) + for (h, c) in zip(list_hidden_states, list_cell_states) + ] + return ans + + +def stack_states( + states_list: List[Tuple[torch.Tensor, torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Stack list of lstm states corresponding to separate utterances into a single + lstm state so that it can be used as an input for lstm when those utterances + are formed into a batch. + + Args: + state_list: + Each element in state_list corresponds to the lstm state for a single + utterance. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. 
+ ``states[i][1]`` is the lstm cell states of i-th utterance. + + + Returns: + A new state corresponding to a batch of utterances. + It is a tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. + """ + hidden_states = torch.cat([s[0] for s in states_list], dim=1) + cell_states = torch.cat([s[1] for s in states_list], dim=1) + ans = (hidden_states, cell_states) + return ans + + +class RNN(EncoderInterface): + """ + Args: + num_features (int): + Number of input features. + subsampling_factor (int): + Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa + d_model (int): + Output dimension (default=512). + dim_feedforward (int): + Feedforward dimension (default=2048). + rnn_hidden_size (int): + Hidden dimension for lstm layers (default=1024). + grad_norm_threshold: + For each sequence element in batch, its gradient will be + filtered out if the gradient norm is larger than + `grad_norm_threshold * median`, where `median` is the median + value of gradient norms of all elememts in batch. + num_encoder_layers (int): + Number of encoder layers (default=12). + dropout (float): + Dropout rate (default=0.1). + layer_dropout (float): + Dropout value for model-level warmup (default=0.075). + aux_layer_period (int): + Period of auxiliary layers used for random combiner during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. + """ + + def __init__( + self, + num_features: int, + subsampling_factor: int = 4, + d_model: int = 512, + dim_feedforward: int = 2048, + rnn_hidden_size: int = 1024, + grad_norm_threshold: float = 10.0, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + aux_layer_period: int = 0, + ) -> None: + super(RNN, self).__init__() + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.num_encoder_layers = num_encoder_layers + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + encoder_layer = RNNEncoderLayer( + d_model=d_model, + dim_feedforward=dim_feedforward, + rnn_hidden_size=rnn_hidden_size, + grad_norm_threshold=grad_norm_threshold, + dropout=dropout, + layer_dropout=layer_dropout, + ) + self.encoder = RNNEncoder( + encoder_layer, + num_encoder_layers, + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ) + if aux_layer_period > 0 + else None, + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C), where N is the batch size, + T is the sequence length, C is the feature dimension. + x_lens: + A tensor of shape (N,), containing the number of frames in `x` + before padding. + states: + A tuple of 2 tensors (optional). It is for streaming inference. 
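The `aux_layers` expression in `RNN.__init__` above is easiest to read with concrete numbers. A small sketch, using an illustrative `aux_layer_period=3` (the default of 0 disables the random combiner entirely):

```python
num_encoder_layers, aux_layer_period = 12, 3

aux_layers = list(
    range(num_encoder_layers // 3, num_encoder_layers - 1, aux_layer_period)
)
print(aux_layers)  # [4, 7, 10]

# RNNEncoder.__init__ (further below) additionally appends the final layer:
print(aux_layers + [num_encoder_layers - 1])  # [4, 7, 10, 11]
```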
+ states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + A tuple of 3 tensors: + - embeddings: its shape is (N, T', d_model), where T' is the output + sequence lengths. + - lengths: a tensor of shape (batch_size,) containing the number of + frames in `embeddings` before padding. + - updated states, whose shape is the same as the input states. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + if not torch.jit.is_tracing(): + assert x.size(0) == lengths.max().item() + + if states is None: + x = self.encoder(x, warmup=warmup)[0] + # torch.jit.trace requires returned types to be the same as annotated # noqa + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_encoder_layers, + x.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_encoder_layers, + x.size(1), + self.rnn_hidden_size, + ) + x, new_states = self.encoder(x, states) + + x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, batch_size: int = 1, device: torch.device = torch.device("cpu") + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get model initial states.""" + # for rnn hidden states + hidden_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.d_model), device=device + ) + cell_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.rnn_hidden_size), + device=device, + ) + return (hidden_states, cell_states) + + +class RNNEncoderLayer(nn.Module): + """ + RNNEncoderLayer is made up of lstm and feedforward networks. + For stable training, in each lstm module, gradient filter + is applied to filter out extremely large elements in batch gradients + and also the module parameters with soft masks. + + Args: + d_model: + The number of expected features in the input (required). + dim_feedforward: + The dimension of feedforward network model (default=2048). + rnn_hidden_size: + The hidden dimension of rnn layer. + grad_norm_threshold: + For each sequence element in batch, its gradient will be + filtered out if the gradient norm is larger than + `grad_norm_threshold * median`, where `median` is the median + value of gradient norms of all elememts in batch. + dropout: + The dropout value (default=0.1). + layer_dropout: + The dropout value for model-level warmup (default=0.075). 
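For streaming use, the per-utterance states produced by `unstack_states()` must round-trip through `stack_states()` unchanged. A quick sketch, assuming it is run from inside `egs/librispeech/ASR/lstm_transducer_stateless3` so that `lstm.py` is importable; the shapes follow `get_init_states()` with the default configuration:

```python
import torch
from lstm import stack_states, unstack_states  # this file

num_layers, d_model, rnn_hidden_size = 12, 512, 1024
batch_size = 4

h = torch.randn(num_layers, batch_size, d_model)
c = torch.randn(num_layers, batch_size, rnn_hidden_size)

per_utt = unstack_states((h, c))   # list of 4 single-utterance states
assert per_utt[0][0].shape == (num_layers, 1, d_model)
assert per_utt[0][1].shape == (num_layers, 1, rnn_hidden_size)

restacked = stack_states(per_utt)  # back to one batched state
assert torch.equal(restacked[0], h) and torch.equal(restacked[1], c)
```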
+ """ + + def __init__( + self, + d_model: int, + dim_feedforward: int, + rnn_hidden_size: int, + grad_norm_threshold: float = 10.0, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + super(RNNEncoderLayer, self).__init__() + self.layer_dropout = layer_dropout + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model) + + self.lstm = ScaledLSTM( + input_size=d_model, + hidden_size=rnn_hidden_size, + proj_size=d_model if rnn_hidden_size > d_model else 0, + num_layers=1, + dropout=0.0, + grad_norm_threshold=grad_norm_threshold, + ) + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer. + + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (1, N, d_model); + states[1] is the cell states of all layers, + with shape of (1, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # lstm module + if states is None: + src_lstm = self.lstm(src)[0] + # torch.jit.trace requires returned types be the same as annotated + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == (1, src.size(1), self.d_model) + # for cell state + assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) + src_lstm, new_states = self.lstm(src, states) + src = src + self.dropout(src_lstm) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src, new_states + + +class RNNEncoder(nn.Module): + """ + RNNEncoder is a stack of N encoder layers. + + Args: + encoder_layer: + An instance of the RNNEncoderLayer() class (required). + num_layers: + The number of sub-encoder-layers in the encoder (required). 
+ """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + aux_layers: Optional[List[int]] = None, + ) -> None: + super(RNNEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + self.d_model = encoder_layer.d_model + self.rnn_hidden_size = encoder_layer.rnn_hidden_size + + self.aux_layers: List[int] = [] + self.combiner: Optional[nn.Module] = None + if aux_layers is not None: + assert len(set(aux_layers)) == len(aux_layers) + assert num_layers - 1 not in aux_layers + self.aux_layers = aux_layers + [num_layers - 1] + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer in turn. + + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + """ + if states is not None: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_layers, + src.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_layers, + src.size(1), + self.rnn_hidden_size, + ) + + output = src + + outputs = [] + + new_hidden_states = [] + new_cell_states = [] + + for i, mod in enumerate(self.layers): + if states is None: + output = mod(output, warmup=warmup)[0] + else: + layer_state = ( + states[0][i : i + 1, :, :], # h: (1, N, d_model) + states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) + ) + output, (h, c) = mod(output, layer_state) + new_hidden_states.append(h) + new_cell_states.append(c) + + if self.combiner is not None and i in self.aux_layers: + outputs.append(output) + + if self.combiner is not None: + output = self.combiner(outputs) + + if states is None: + new_states = (torch.empty(0), torch.empty(0)) + else: + new_states = ( + torch.cat(new_hidden_states, dim=0), + torch.cat(new_cell_states, dim=0), + ) + + return output, new_states + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-3)//2-1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >= 9, in_channels >= 9. + out_channels + Output dim. 
The output shape is (N, ((T-3)//2-1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + """ + assert in_channels >= 9 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=0, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-3)//2-1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-3)//2-1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class RandomCombine(nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + num_inputs: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0, + ) -> None: + """ + Args: + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: + The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: + The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: + A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, or combinations of layers, to use, + is conceptually as follows:: + + With probability `pure_prob`:: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. 
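The shape arithmetic in `Conv2dSubsampling` above (one stride-1 3x3 conv followed by two stride-2 3x3 convs, all unpadded) gives `T' = ((T - 3) // 2 - 1) // 2`, and the same reduction is applied to the feature axis before the output `ScaledLinear`. A quick numeric check, assuming 80-dim fbank and the default channel sizes:

```python
def sub4(n: int) -> int:
    # ((n - 3) // 2 - 1) // 2, matching both Conv2dSubsampling and the
    # bit-shift length formula in RNN.forward().
    return ((n - 3) // 2 - 1) // 2

T, idim, layer3_channels = 100, 80, 128
print(sub4(T))                       # 23 output frames for 100 input frames
print(layer3_channels * sub4(idim))  # 128 * 18 = 2304 inputs to self.out
```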
+ Else:: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super().__init__() + assert 0 <= pure_prob <= 1, pure_prob + assert 0 < final_weight < 1, final_weight + assert num_inputs >= 1 + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev = stddev + + self.final_log_weight = ( + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) + .log() + .item() + ) + + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training or torch.jit.is_scripting(): + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. 
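The mixed-weight branch of `RandomCombine` (implemented in `_get_random_mixed_weights()` just below) can be simulated on its own to see the effect the docstring describes. A hedged sketch with illustrative values:

```python
import torch

num_inputs, final_weight, stddev = 4, 0.5, 2.0
num_frames = 100_000

# Same boost that RandomCombine.__init__ gives the final layer's log-weight.
final_log_weight = torch.tensor(
    (final_weight / (1 - final_weight)) * (num_inputs - 1)
).log().item()

logprobs = torch.randn(num_frames, num_inputs) * stddev
logprobs[:, -1] += final_log_weight
weights = logprobs.softmax(dim=1)

print(weights.sum(dim=1)[:3])   # every row sums to 1
print(weights[:, -1].mean())    # average final-layer weight; with stddev > 0
                                # it is not exactly final_weight, as noted above
```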
+ """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. + """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print( + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa + ) + num_inputs = 3 + num_channels = 50 + m = RandomCombine( + num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev, + ) + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_random_combine_main(): + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + feature_dim = 50 + c = RNN(num_features=feature_dim, d_model=128) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + feature_dim = 80 + m = RNN( + num_features=feature_dim, + d_model=512, + rnn_hidden_size=1024, + dim_feedforward=2048, + num_encoder_layers=12, + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. 
+ f = m( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/model.py b/egs/librispeech/ASR/lstm_transducer_stateless3/model.py new file mode 120000 index 000000000..1bf04f3a4 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/model.py @@ -0,0 +1 @@ +../lstm_transducer_stateless/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/optim.py b/egs/librispeech/ASR/lstm_transducer_stateless3/optim.py new file mode 120000 index 000000000..e2deb4492 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py new file mode 100755 index 000000000..0e48fef04 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) greedy search +./lstm_transducer_stateless3/pretrained.py \ + --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./lstm_transducer_stateless3/pretrained.py \ + --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./lstm_transducer_stateless3/pretrained.py \ + --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./lstm_transducer_stateless3/pretrained.py \ + --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./lstm_transducer_stateless3/exp/epoch-xx.pt`. 
+ +Note: ./lstm_transducer_stateless3/exp/pretrained.pt is generated by +./lstm_transducer_stateless3/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + 
for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/scaling.py b/egs/librispeech/ASR/lstm_transducer_stateless3/scaling.py new file mode 120000 index 000000000..09d802cc4 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/lstm_transducer_stateless3/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless3/stream.py new file mode 120000 index 000000000..71ea6dff1 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/stream.py @@ -0,0 +1 @@ +../lstm_transducer_stateless/stream.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py new file mode 100755 index 000000000..cfa918ed5 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -0,0 +1,968 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
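# ---------------------------------------------------------------------------
# Illustrative aside (a minimal sketch, not part of the patch): the hypotheses
# produced by pretrained.py above are sequences of BPE token ids; sp.decode()
# maps each id sequence back to a string, which is then split into words for
# printing.  This sketch assumes the recipe's BPE model at
# data/lang_bpe_500/bpe.model is available; the ids below are made up.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")
hyp_tokens = [[23, 45, 67], [12, 89]]  # toy id sequences, not real model output
hyps = [text.split() for text in sp.decode(hyp_tokens)]
print(hyps)
# ---------------------------------------------------------------------------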
+""" +Usage: +(1) greedy search +./lstm_transducer_stateless3/streaming_decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir lstm_transducer_stateless3/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method greedy_search \ + --use-averaged-model True + +(2) modified beam search +./lstm_transducer_stateless3/streaming_decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir lstm_transducer_stateless3/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method modified_beam_search \ + --use-averaged-model True \ + --beam-size 4 + +(3) fast beam search +./lstm_transducer_stateless3/streaming_decode.py \ + --epoch 40 \ + --avg 20 \ + --exp-dir lstm_transducer_stateless3/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method fast_beam_search \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" +import argparse +import logging +import warnings +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from lstm import LOG_EPSILON, stack_states, unstack_states +from stream import Stream +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=40, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless3/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--sampling-rate", + type=float, + default=16000, + help="Sample rate of the audio", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded in parallel", + ) + + add_model_arguments(parser) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. 
+ """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + T = encoder_out.size(1) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (batch_size, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], + beam: int = 4, +): + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + beam: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + streams: List[Stream], + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using modified beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + streams: + A list of stream objects. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. 
+ """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + assert B == len(streams) + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyps[i] + + +def decode_one_chunk( + model: nn.Module, + streams: List[Stream], + params: AttributeDict, + decoding_graph: Optional[k2.Fsa] = None, +) -> List[int]: + """ + Args: + model: + The Transducer model. + streams: + A list of Stream objects. + params: + It is returned by :func:`get_params`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + A list of indexes indicating the finished streams. 
+ """ + device = next(model.parameters()).device + + feature_list = [] + feature_len_list = [] + state_list = [] + num_processed_frames_list = [] + + for stream in streams: + # We should first get `stream.num_processed_frames` + # before calling `stream.get_feature_chunk()` + # since `stream.num_processed_frames` would be updated + num_processed_frames_list.append(stream.num_processed_frames) + feature = stream.get_feature_chunk() + feature_len = feature.size(0) + feature_list.append(feature) + feature_len_list.append(feature_len) + state_list.append(stream.states) + + features = pad_sequence( + feature_list, batch_first=True, padding_value=LOG_EPSILON + ).to(device) + feature_lens = torch.tensor(feature_len_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) + + # Make sure it has at least 1 frame after subsampling + tail_length = params.subsampling_factor + 5 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPSILON, + ) + + # Stack states of all streams + states = stack_states(state_list) + + encoder_out, encoder_out_lens, states = model.encoder( + x=features, + x_lens=feature_lens, + states=states, + ) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + streams=streams, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=streams, + encoder_out=encoder_out, + beam=params.beam_size, + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + processed_lens = ( + num_processed_frames // params.subsampling_factor + + encoder_out_lens + ) + fast_beam_search_one_best( + model=model, + streams=streams, + encoder_out=encoder_out, + processed_lens=processed_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + # Update cached states of each stream + state_list = unstack_states(states) + for i, s in enumerate(state_list): + streams[i].states = s + + finished_streams = [i for i, stream in enumerate(streams) if stream.done] + return finished_streams + + +def create_streaming_feature_extractor() -> Fbank: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return Fbank(opts) + + +def decode_dataset( + cuts: CutSet, + model: nn.Module, + params: AttributeDict, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +): + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The Transducer model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. 
+ + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = next(model.parameters()).device + + log_interval = 300 + + fbank = create_streaming_feature_extractor() + + decode_results = [] + streams = [] + for num, cut in enumerate(cuts): + # Each utterance has a Stream. + stream = Stream( + params=params, + cut_id=cut.id, + decoding_graph=decoding_graph, + device=device, + LOG_EPS=LOG_EPSILON, + ) + + stream.states = model.encoder.get_init_states(device=device) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + feature = fbank(samples) + stream.set_feature(feature) + stream.ground_truth = cut.supervisions[0].text + + streams.append(stream) + + while len(streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + while len(streams) > 0: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + else: + key = f"beam_size_{params.beam_size}" + + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=sorted(results)) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
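# ---------------------------------------------------------------------------
# Illustrative aside (a minimal sketch, not part of the patch): the code just
# below reduces the per-method WERs to a tab-separated "settings WER" summary
# sorted with the best result first.  A toy version of that reduction with
# made-up numbers:
toy_wers = {"greedy_search": 3.85, "beam_size_4": 3.70}
for name, value in sorted(toy_wers.items(), key=lambda x: x[1]):
    print(f"{name}\t{value}")
# beam_size_4    3.7
# greedy_search  3.85
# ---------------------------------------------------------------------------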
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-streaming-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + params.device = device + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if 
params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + model=model, + params=params, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + torch.manual_seed(20220810) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/test_model.py b/egs/librispeech/ASR/lstm_transducer_stateless3/test_model.py new file mode 100755 index 000000000..03dfe1997 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/test_model.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
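# ---------------------------------------------------------------------------
# Illustrative aside (a minimal sketch, not part of the patch): with
# --use-averaged-model, streaming_decode.py above loads only two checkpoints
# and interpolates between the running parameter averages stored in them.
# The filename arithmetic it performs, using the example values
# --epoch 40 --avg 20 from the usage docstring:
epoch, avg = 40, 20
start = epoch - avg
filename_start = f"lstm_transducer_stateless3/exp/epoch-{start}.pt"  # excluded
filename_end = f"lstm_transducer_stateless3/exp/epoch-{epoch}.pt"
print(filename_start, filename_end)  # epoch-20.pt and epoch-40.pt
# ---------------------------------------------------------------------------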
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./lstm_transducer_stateless/test_model.py +""" + +import os +from pathlib import Path + +import torch +from export import ( + export_decoder_model_jit_trace, + export_encoder_model_jit_trace, + export_joiner_model_jit_trace, +) +from lstm import stack_states, unstack_states +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + params.encoder_dim = 512 + params.rnn_hidden_size = 1024 + params.num_encoder_layers = 12 + params.aux_layer_period = 0 + params.exp_dir = Path("exp_test_model") + + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + if not os.path.exists(params.exp_dir): + os.path.mkdir(params.exp_dir) + + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + print("The model has been successfully exported using jit.trace.") + + +def test_states_stack_and_unstack(): + layer, batch, hidden, cell = 12, 100, 512, 1024 + states = ( + torch.randn(layer, batch, hidden), + torch.randn(layer, batch, cell), + ) + states2 = stack_states(unstack_states(states)) + assert torch.allclose(states[0], states2[0]) + assert torch.allclose(states[1], states2[1]) + + +def main(): + test_model() + test_states_stack_and_unstack() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/test_scaling_converter.py b/egs/librispeech/ASR/lstm_transducer_stateless3/test_scaling_converter.py new file mode 100644 index 000000000..7567dd58c --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/test_scaling_converter.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./lstm_transducer_stateless/test_scaling_converter.py +""" + +import copy + +import torch +from scaling import ( + ScaledConv1d, + ScaledConv2d, + ScaledEmbedding, + ScaledLinear, + ScaledLSTM, +) +from scaling_converter import ( + convert_scaled_to_non_scaled, + scaled_conv1d_to_conv1d, + scaled_conv2d_to_conv2d, + scaled_embedding_to_embedding, + scaled_linear_to_linear, + scaled_lstm_to_lstm, +) +from train import get_params, get_transducer_model + + +def get_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + params.encoder_dim = 512 + params.rnn_hidden_size = 1024 + params.num_encoder_layers = 12 + params.aux_layer_period = -1 + + model = get_transducer_model(params) + return model + + +def test_scaled_linear_to_linear(): + N = 5 + in_features = 10 + out_features = 20 + for bias in [True, False]: + scaled_linear = ScaledLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + ) + linear = scaled_linear_to_linear(scaled_linear) + x = torch.rand(N, in_features) + + y1 = scaled_linear(x) + y2 = linear(x) + assert torch.allclose(y1, y2) + + jit_scaled_linear = torch.jit.script(scaled_linear) + jit_linear = torch.jit.script(linear) + + y3 = jit_scaled_linear(x) + y4 = jit_linear(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_scaled_conv1d_to_conv1d(): + in_channels = 3 + for bias in [True, False]: + scaled_conv1d = ScaledConv1d( + in_channels, + 6, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + conv1d = scaled_conv1d_to_conv1d(scaled_conv1d) + + x = torch.rand(20, in_channels, 10) + y1 = scaled_conv1d(x) + y2 = conv1d(x) + assert torch.allclose(y1, y2) + + jit_scaled_conv1d = torch.jit.script(scaled_conv1d) + jit_conv1d = torch.jit.script(conv1d) + + y3 = jit_scaled_conv1d(x) + y4 = jit_conv1d(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_scaled_conv2d_to_conv2d(): + in_channels = 1 + for bias in [True, False]: + scaled_conv2d = ScaledConv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + padding=1, + bias=bias, + ) + + conv2d = scaled_conv2d_to_conv2d(scaled_conv2d) + + x = torch.rand(20, in_channels, 10, 20) + y1 = scaled_conv2d(x) + y2 = conv2d(x) + assert torch.allclose(y1, y2) + + jit_scaled_conv2d = torch.jit.script(scaled_conv2d) + jit_conv2d = torch.jit.script(conv2d) + + y3 = jit_scaled_conv2d(x) + y4 = jit_conv2d(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_scaled_embedding_to_embedding(): + scaled_embedding = ScaledEmbedding( + num_embeddings=500, + embedding_dim=10, + padding_idx=0, + ) + embedding = scaled_embedding_to_embedding(scaled_embedding) + + for s in [10, 100, 300, 500, 800, 1000]: + x = torch.randint(low=0, high=500, size=(s,)) + scaled_y = scaled_embedding(x) + y = embedding(x) + assert torch.equal(scaled_y, y) + + +def test_scaled_lstm_to_lstm(): + input_size = 512 + batch_size = 20 + for bias in [True, False]: + for hidden_size in [512, 1024]: + scaled_lstm = ScaledLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=bias, + proj_size=0 if hidden_size == input_size else input_size, + ) + + lstm = scaled_lstm_to_lstm(scaled_lstm) + + x = torch.rand(200, batch_size, input_size) + h0 = torch.randn(1, batch_size, input_size) + c0 = torch.randn(1, batch_size, hidden_size) + + y1, (h1, c1) = scaled_lstm(x, (h0, c0)) + y2, (h2, c2) 
= lstm(x, (h0, c0)) + assert torch.allclose(y1, y2) + assert torch.allclose(h1, h2) + assert torch.allclose(c1, c2) + + jit_scaled_lstm = torch.jit.trace(lstm, (x, (h0, c0))) + y3, (h3, c3) = jit_scaled_lstm(x, (h0, c0)) + assert torch.allclose(y1, y3) + assert torch.allclose(h1, h3) + assert torch.allclose(c1, c3) + + +def test_convert_scaled_to_non_scaled(): + for inplace in [False, True]: + model = get_model() + model.eval() + + orig_model = copy.deepcopy(model) + + converted_model = convert_scaled_to_non_scaled(model, inplace=inplace) + + model = orig_model + + # test encoder + N = 2 + T = 100 + vocab_size = model.decoder.vocab_size + + x = torch.randn(N, T, 80, dtype=torch.float32) + x_lens = torch.full((N,), x.size(1)) + + e1, e1_lens, _ = model.encoder(x, x_lens) + e2, e2_lens, _ = converted_model.encoder(x, x_lens) + + assert torch.all(torch.eq(e1_lens, e2_lens)) + assert torch.allclose(e1, e2), (e1 - e2).abs().max() + + # test decoder + U = 50 + y = torch.randint(low=1, high=vocab_size - 1, size=(N, U)) + + d1 = model.decoder(y) + d2 = model.decoder(y) + + assert torch.allclose(d1, d2) + + # test simple projection + lm1 = model.simple_lm_proj(d1) + am1 = model.simple_am_proj(e1) + + lm2 = converted_model.simple_lm_proj(d2) + am2 = converted_model.simple_am_proj(e2) + + assert torch.allclose(lm1, lm2) + assert torch.allclose(am1, am2) + + # test joiner + e = torch.rand(2, 3, 4, 512) + d = torch.rand(2, 3, 4, 512) + + j1 = model.joiner(e, d) + j2 = converted_model.joiner(e, d) + assert torch.allclose(j1, j2) + + +@torch.no_grad() +def main(): + test_scaled_linear_to_linear() + test_scaled_conv1d_to_conv1d() + test_scaled_conv2d_to_conv2d() + test_scaled_embedding_to_embedding() + test_scaled_lstm_to_lstm() + test_convert_scaled_to_non_scaled() + + +if __name__ == "__main__": + torch.manual_seed(20220730) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py new file mode 100755 index 000000000..dc3697ae7 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -0,0 +1,1138 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
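# ---------------------------------------------------------------------------
# Illustrative aside (a conceptual sketch, not the actual scaling.py code):
# the equivalence tests above hold because converting a "scaled" layer to a
# plain torch layer only folds the learned scale into the stored weight.
# Roughly, if the scaled layer multiplies its weight by exp(log_scale) on
# every forward call, the converted layer stores that product once:
import torch
import torch.nn.functional as F

weight = torch.randn(4, 3)
log_scale = torch.tensor(0.25)  # a hypothetical learned log-scale
x = torch.randn(2, 3)

folded_weight = weight * log_scale.exp()          # done once by the converter
y_scaled = F.linear(x, weight * log_scale.exp())  # scale applied at run time
y_plain = F.linear(x, folded_weight)              # plain layer, folded weight
assert torch.allclose(y_scaled, y_plain)
# ---------------------------------------------------------------------------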
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./lstm_transducer_stateless3/train.py \ + --world-size 4 \ + --num-epochs 40 \ + --start-epoch 1 \ + --exp-dir lstm_transducer_stateless3/exp \ + --full-libri 1 \ + --max-duration 500 + +# For mix precision training: + +./lstm_transducer_stateless3/train.py \ + --world-size 4 \ + --num-epochs 40 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir lstm_transducer_stateless3/exp \ + --full-libri 1 \ + --max-duration 550 +""" + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from lstm import RNN +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + display_and_save_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of RNN encoder layers..", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=512, + help="Encoder output dimesion.", + ) + + parser.add_argument( + "--rnn-hidden-size", + type=int, + default=1024, + help="Hidden dim for LSTM layers.", + ) + + parser.add_argument( + "--aux-layer-period", + type=int, + default=0, + help="""Peroid of auxiliary layers used for randomly combined during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. 
+ """, + ) + + parser.add_argument( + "--grad-norm-threshold", + type=float, + default=25.0, + help="""For each sequence element in batch, its gradient will be + filtered out if the gradient norm is larger than + `grad_norm_threshold * median`, where `median` is the median + value of gradient norms of all elememts in batch.""", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=40, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "dim_feedforward": 2048, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = RNN( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + rnn_hidden_size=params.rnn_hidden_size, + grad_norm_threshold=params.grad_norm_threshold, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + aux_layer_period=params.aux_layer_period, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + reduction="none", + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + if not torch.all(is_finite): + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + # If either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if ( + batch_idx % params.log_interval == 0 + and not params.print_diagnostics + ): + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if ( + batch_idx > 0 + and batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + # # overwrite it + # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs] + # print(scheduler.base_lrs) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 282ce3737..2d5724d30 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -354,7 +354,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -424,7 +424,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index ab6cf336c..b11fb960a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -511,7 +511,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[str, Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -585,7 +585,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[str, Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 9a0405c57..b04a74a19 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -476,8 +476,8 @@ class ConformerEncoderLayer(nn.Module): self, src: Tensor, pos_emb: Tensor, - src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + src_mask: Optional[Tensor] = None, warmup: float = 1.0, ) -> Tensor: """ @@ -486,8 +486,8 @@ class ConformerEncoderLayer(nn.Module): Args: src: the sequence to the encoder layer (required). pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + src_mask: the mask for the src sequence (optional). warmup: controls selective bypass of of layers; if < 1.0, we will bypass layers more frequently. 
Shape: @@ -527,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -661,8 +663,8 @@ class ConformerEncoder(nn.Module): self, src: Tensor, pos_emb: Tensor, - mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + mask: Optional[Tensor] = None, warmup: float = 1.0, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -670,8 +672,8 @@ class ConformerEncoder(nn.Module): Args: src: the sequence to the encoder (required). pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + mask: the mask for the src sequence (optional). warmup: controls selective bypass of of layers; if < 1.0, we will bypass layers more frequently. @@ -930,7 +932,7 @@ class RelPositionMultiheadAttention(nn.Module): value: Tensor, pos_emb: Tensor, key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, + need_weights: bool = False, attn_mask: Optional[Tensor] = None, left_context: int = 0, ) -> Tuple[Tensor, Optional[Tensor]]: @@ -1057,7 +1059,7 @@ class RelPositionMultiheadAttention(nn.Module): out_proj_bias: Tensor, training: bool = True, key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, + need_weights: bool = False, attn_mask: Optional[Tensor] = None, left_context: int = 0, ) -> Tuple[Tensor, Optional[Tensor]]: @@ -1457,6 +1459,7 @@ class ConvolutionModule(nn.Module): x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """Compute convolution module. @@ -1467,6 +1470,7 @@ class ConvolutionModule(nn.Module): right_context: How many future frames the attention can see in current chunk. Note: It's not that each individual frame has `right_context` frames + src_key_padding_mask: the mask for the src keys per batch (optional). of right context, some have more. Returns: @@ -1486,6 +1490,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) if self.causal and self.lorder > 0: if cache is None: # Make depthwise_conv causal by diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 34fd31e7e..7852dafc9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -534,7 +534,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -608,7 +608,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index cc3caecc7..8c572a9ef 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -16,6 +16,7 @@ import collections +import random from itertools import repeat from typing import Optional, Tuple @@ -111,6 +112,76 @@ class ActivationBalancerFunction(torch.autograd.Function): return x_grad - neg_delta_grad, None, None, None, None, None, None +class GradientFilterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + batch_dim: int, # e.g., 1 + threshold: float, # e.g., 10.0 + *params: Tensor, # module parameters + ) -> Tuple[Tensor, ...]: + if x.requires_grad: + if batch_dim < 0: + batch_dim += x.ndim + ctx.batch_dim = batch_dim + ctx.threshold = threshold + return (x,) + params + + @staticmethod + def backward( + ctx, + x_grad: Tensor, + *param_grads: Tensor, + ) -> Tuple[Tensor, ...]: + eps = 1.0e-20 + dim = ctx.batch_dim + norm_dims = [d for d in range(x_grad.ndim) if d != dim] + norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() + median_norm = norm_of_batch.median() + + cutoff = median_norm * ctx.threshold + inv_mask = (cutoff + norm_of_batch) / (cutoff + eps) + mask = 1.0 / (inv_mask + eps) + x_grad = x_grad * mask + + avg_mask = 1.0 / (inv_mask.mean() + eps) + param_grads = [avg_mask * g for g in param_grads] + + return (x_grad, None, None) + tuple(param_grads) + + +class GradientFilter(torch.nn.Module): + """This is used to filter out elements that have extremely large gradients + in batch and the module parameters with soft masks. + + Args: + batch_dim (int): + The batch dimension. + threshold (float): + For each element in batch, its gradient will be + filtered out if the gradient norm is larger than + `grad_norm_threshold * median`, where `median` is the median + value of gradient norms of all elememts in batch. 
+ """ + + def __init__(self, batch_dim: int = 1, threshold: float = 10.0): + super(GradientFilter, self).__init__() + self.batch_dim = batch_dim + self.threshold = threshold + + def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]: + if torch.jit.is_scripting() or is_jit_tracing(): + return (x,) + params + else: + return GradientFilterFunction.apply( + x, + self.batch_dim, + self.threshold, + *params, + ) + + class BasicNorm(torch.nn.Module): """ This is intended to be a simpler, and hopefully cheaper, replacement for @@ -195,7 +266,7 @@ class ScaledLinear(nn.Linear): *args, initial_scale: float = 1.0, initial_speed: float = 1.0, - **kwargs + **kwargs, ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -242,7 +313,7 @@ class ScaledConv1d(nn.Conv1d): *args, initial_scale: float = 1.0, initial_speed: float = 1.0, - **kwargs + **kwargs, ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -314,7 +385,7 @@ class ScaledConv2d(nn.Conv2d): *args, initial_scale: float = 1.0, initial_speed: float = 1.0, - **kwargs + **kwargs, ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -389,7 +460,8 @@ class ScaledLSTM(nn.LSTM): *args, initial_scale: float = 1.0, initial_speed: float = 1.0, - **kwargs + grad_norm_threshold: float = 10.0, + **kwargs, ): if "bidirectional" in kwargs: assert kwargs["bidirectional"] is False @@ -404,6 +476,10 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) + self.grad_filter = GradientFilter( + batch_dim=1, threshold=grad_norm_threshold + ) + self._reset_parameters( initial_speed ) # Overrides the reset_parameters in base class @@ -513,10 +589,14 @@ class ScaledLSTM(nn.LSTM): hx = (h_zeros, c_zeros) self.check_forward_args(input, hx, None) + + flat_weights = self._get_flat_weights() + input, *flat_weights = self.grad_filter(input, *flat_weights) + result = _VF.lstm( input, hx, - self._get_flat_weights(), + flat_weights, self.bias, self.num_layers, self.dropout, @@ -557,6 +637,7 @@ class ActivationBalancer(torch.nn.Module): max_abs: the maximum average-absolute-value per channel, which we allow, before we start to modify the derivatives to prevent this. + balance_prob: the probability to apply the ActivationBalancer. 
""" def __init__( @@ -567,6 +648,7 @@ class ActivationBalancer(torch.nn.Module): max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0, + balance_prob: float = 0.25, ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim @@ -575,9 +657,11 @@ class ActivationBalancer(torch.nn.Module): self.max_factor = max_factor self.min_abs = min_abs self.max_abs = max_abs + assert 0 < balance_prob <= 1, balance_prob + self.balance_prob = balance_prob def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or is_jit_tracing(): + if random.random() >= self.balance_prob: return x else: return ActivationBalancerFunction.apply( @@ -585,7 +669,7 @@ class ActivationBalancer(torch.nn.Module): self.channel_dim, self.min_positive, self.max_positive, - self.max_factor, + self.max_factor / self.balance_prob, self.min_abs, self.max_abs, ) @@ -891,9 +975,54 @@ def _test_scaled_lstm(): assert c.shape == (1, N, dim_hidden) +def _test_grad_filter(): + threshold = 50.0 + time, batch, channel = 200, 5, 128 + grad_filter = GradientFilter(batch_dim=1, threshold=threshold) + + for i in range(2): + x = torch.randn(time, batch, channel, requires_grad=True) + w = nn.Parameter(torch.ones(5)) + b = nn.Parameter(torch.zeros(5)) + + x_out, w_out, b_out = grad_filter(x, w, b) + + w_out_grad = torch.randn_like(w) + b_out_grad = torch.randn_like(b) + x_out_grad = torch.rand_like(x) + if i % 2 == 1: + # The gradient norm of the first element must be larger than + # `threshold * median`, where `median` is the median value + # of gradient norms of all elements in batch. + x_out_grad[:, 0, :] = torch.full((time, channel), threshold) + + torch.autograd.backward( + [x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad] + ) + + print( + "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa + i % 2 == 1, + ) + + print( + "_test_grad_filter: x_out_grad norm = ", + (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), + ) + print( + "_test_grad_filter: x.grad norm = ", + (x.grad ** 2).mean(dim=(0, 2)).sqrt(), + ) + print("_test_grad_filter: w_out_grad = ", w_out_grad) + print("_test_grad_filter: w.grad = ", w.grad) + print("_test_grad_filter: b_out_grad = ", b_out_grad) + print("_test_grad_filter: b.grad = ", b.grad) + + if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() _test_double_swish_deriv() _test_scaled_lstm() + _test_grad_filter() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 72d6f656c..fa4f1e7d9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -781,7 +781,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index a4687f35d..47217ba05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -62,13 +62,20 @@ It will generates 3 files: `encoder_jit_trace.pt`, --avg 10 \ --onnx 1 -It will generate the following three files in the given `exp_dir`. 
+It will generate the following files in the given `exp_dir`. Check `onnx_check.py` for how to use them. - encoder.onnx - decoder.onnx - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx +Please see ./onnx_pretrained.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. (4) Export `model.state_dict()` @@ -115,7 +122,6 @@ import argparse import logging from pathlib import Path -import onnx import sentencepiece as spm import torch import torch.nn as nn @@ -213,13 +219,15 @@ def get_parser(): type=str2bool, default=False, help="""If True, --jit is ignored and it exports the model - to onnx format. Three files will be generated: + to onnx format. It will generate the following files: - encoder.onnx - decoder.onnx - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx - Check ./onnx_check.py and ./onnx_pretrained.py for how to use them. + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. """, ) @@ -476,65 +484,99 @@ def export_joiner_model_onnx( opset_version: int = 11, ) -> None: """Export the joiner model to ONNX format. - The exported model has two inputs: + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + - decoder_out: a tensor of shape (N, decoder_out_dim) - and has one output: + and produces one output: - - joiner_out: a tensor of shape (N, vocab_size) - - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. 
+ - projected_decoder_out: a tensor of shape (N, joiner_dim) """ + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + joiner_dim = joiner_model.decoder_proj.weight.shape[0] - project_input = True + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + project_input = False # Note: It uses torch.jit.trace() internally torch.onnx.export( joiner_model, - (encoder_out, decoder_out, project_input), + (projected_encoder_out, projected_decoder_out, project_input), joiner_filename, verbose=False, opset_version=opset_version, - input_names=["encoder_out", "decoder_out", "project_input"], + input_names=[ + "projected_encoder_out", + "projected_decoder_out", + "project_input", + ], output_names=["logit"], dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, "logit": {0: "N"}, }, ) logging.info(f"Saved to {joiner_filename}") - -def export_all_in_one_onnx( - encoder_filename: str, - decoder_filename: str, - joiner_filename: str, - all_in_one_filename: str, -): - encoder_onnx = onnx.load(encoder_filename) - decoder_onnx = onnx.load(decoder_filename) - joiner_onnx = onnx.load(joiner_filename) - - encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/") - decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/") - joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/") - - combined_model = onnx.compose.merge_models( - encoder_onnx, decoder_onnx, io_map={} + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, ) - combined_model = onnx.compose.merge_models( - combined_model, joiner_onnx, io_map={} + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, ) - onnx.save(combined_model, all_in_one_filename) - logging.info(f"Saved to {all_in_one_filename}") + logging.info(f"Saved to {decoder_proj_filename}") @torch.no_grad() @@ -628,14 +670,6 @@ def main(): joiner_filename, opset_version=opset_version, ) - - all_in_one_filename = params.exp_dir / "all_in_one.onnx" - export_all_in_one_onnx( - encoder_filename, - decoder_filename, - joiner_filename, - all_in_one_filename, - ) elif params.jit is True: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.script()") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py new file mode 120000 index 
000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 3da31b7ce..fb9adb44a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -63,6 +63,20 @@ def get_parser(): help="Path to the onnx joiner model", ) + parser.add_argument( + "--onnx-joiner-encoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner encoder projection model", + ) + + parser.add_argument( + "--onnx-joiner-decoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner decoder projection model", + ) + return parser @@ -70,11 +84,13 @@ def test_encoder( model: torch.jit.ScriptModule, encoder_session: ort.InferenceSession, ): - encoder_inputs = encoder_session.get_inputs() - assert encoder_inputs[0].name == "x" - assert encoder_inputs[1].name == "x_lens" - assert encoder_inputs[0].shape == ["N", "T", 80] - assert encoder_inputs[1].shape == ["N"] + inputs = encoder_session.get_inputs() + outputs = encoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", "T", 80] + assert inputs[1].shape == ["N"] for N in [1, 5]: for T in [12, 25]: @@ -84,11 +100,11 @@ def test_encoder( x_lens[0] = T encoder_inputs = { - "x": x.numpy(), - "x_lens": x_lens.numpy(), + input_names[0]: x.numpy(), + input_names[1]: x_lens.numpy(), } encoder_out, encoder_out_lens = encoder_session.run( - ["encoder_out", "encoder_out_lens"], + output_names, encoder_inputs, ) @@ -96,7 +112,9 @@ def test_encoder( encoder_out = torch.from_numpy(encoder_out) assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max() + (encoder_out - torch_encoder_out).abs().max(), + encoder_out.shape, + torch_encoder_out.shape, ) @@ -104,15 +122,18 @@ def test_decoder( model: torch.jit.ScriptModule, decoder_session: ort.InferenceSession, ): - decoder_inputs = decoder_session.get_inputs() - assert decoder_inputs[0].name == "y" - assert decoder_inputs[0].shape == ["N", 2] + inputs = decoder_session.get_inputs() + outputs = decoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", 2] for N in [1, 5, 10]: y = torch.randint(low=1, high=500, size=(10, 2)) - decoder_inputs = {"y": y.numpy()} + decoder_inputs = {input_names[0]: y.numpy()} decoder_out = decoder_session.run( - ["decoder_out"], + output_names, decoder_inputs, )[0] decoder_out = torch.from_numpy(decoder_out) @@ -126,34 +147,92 @@ def test_decoder( def test_joiner( model: torch.jit.ScriptModule, joiner_session: ort.InferenceSession, + joiner_encoder_proj_session: ort.InferenceSession, + joiner_decoder_proj_session: ort.InferenceSession, ): joiner_inputs = joiner_session.get_inputs() - assert joiner_inputs[0].name == "encoder_out" - assert joiner_inputs[0].shape == ["N", 512] + joiner_outputs = joiner_session.get_outputs() + joiner_input_names = [n.name for n in joiner_inputs] + joiner_output_names = [n.name for n in joiner_outputs] - assert joiner_inputs[1].name == "decoder_out" + assert joiner_inputs[0].shape == ["N", 512] assert joiner_inputs[1].shape == ["N", 512] + 
joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() + encoder_proj_input_name = joiner_encoder_proj_inputs[0].name + + assert joiner_encoder_proj_inputs[0].shape == ["N", 512] + + joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() + encoder_proj_output_name = joiner_encoder_proj_outputs[0].name + + joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() + decoder_proj_input_name = joiner_decoder_proj_inputs[0].name + + assert joiner_decoder_proj_inputs[0].shape == ["N", 512] + + joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() + decoder_proj_output_name = joiner_decoder_proj_outputs[0].name + for N in [1, 5, 10]: encoder_out = torch.rand(N, 512) decoder_out = torch.rand(N, 512) + projected_encoder_out = torch.rand(N, 512) + projected_decoder_out = torch.rand(N, 512) + joiner_inputs = { - "encoder_out": encoder_out.numpy(), - "decoder_out": decoder_out.numpy(), + joiner_input_names[0]: projected_encoder_out.numpy(), + joiner_input_names[1]: projected_decoder_out.numpy(), } - joiner_out = joiner_session.run(["logit"], joiner_inputs)[0] + joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] joiner_out = torch.from_numpy(joiner_out) torch_joiner_out = model.joiner( - encoder_out, - decoder_out, - project_input=True, + projected_encoder_out, + projected_decoder_out, + project_input=False, ) assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( (joiner_out - torch_joiner_out).abs().max() ) + # Now test encoder_proj + joiner_encoder_proj_inputs = { + encoder_proj_input_name: encoder_out.numpy() + } + joiner_encoder_proj_out = joiner_encoder_proj_session.run( + [encoder_proj_output_name], joiner_encoder_proj_inputs + )[0] + joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) + + torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) + assert torch.allclose( + joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 + ), ( + (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) + .abs() + .max() + ) + + # Now test decoder_proj + joiner_decoder_proj_inputs = { + decoder_proj_input_name: decoder_out.numpy() + } + joiner_decoder_proj_out = joiner_decoder_proj_session.run( + [decoder_proj_output_name], joiner_decoder_proj_inputs + )[0] + joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) + + torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) + assert torch.allclose( + joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 + ), ( + (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) + .abs() + .max() + ) + @torch.no_grad() def main(): @@ -185,7 +264,20 @@ def main(): args.onnx_joiner_filename, sess_options=options, ) - test_joiner(model, joiner_session) + joiner_encoder_proj_session = ort.InferenceSession( + args.onnx_joiner_encoder_proj_filename, + sess_options=options, + ) + joiner_decoder_proj_session = ort.InferenceSession( + args.onnx_joiner_decoder_proj_filename, + sess_options=options, + ) + test_joiner( + model, + joiner_session, + joiner_encoder_proj_session, + joiner_decoder_proj_session, + ) logging.info("Finished checking ONNX models") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py deleted file mode 100755 index b4cf8c94a..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py +++ /dev/null @@ -1,284 +0,0 @@ -#!/usr/bin/env python3 -# -# 
Copyright 2022 Xiaomi Corporation (Author: Yunus Emre Ozkose) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. -""" - -import argparse -import logging -import os - -import onnx -import onnx_graphsurgeon as gs -import onnxruntime -import onnxruntime as ort -import torch - -ort.set_default_logger_severity(3) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx-all-in-one-filename", - required=True, - type=str, - help="Path to the onnx all in one model", - ) - - return parser - - -def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, -): - encoder_inputs = encoder_session.get_inputs() - assert encoder_inputs[0].shape == ["N", "T", 80] - assert encoder_inputs[1].shape == ["N"] - encoder_input_names = [i.name for i in encoder_inputs] - encoder_output_names = [i.name for i in encoder_session.get_outputs()] - - for N in [1, 5]: - for T in [12, 25]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T - - encoder_inputs = { - encoder_input_names[0]: x.numpy(), - encoder_input_names[1]: x_lens.numpy(), - } - encoder_out, encoder_out_lens = encoder_session.run( - [encoder_output_names[1], encoder_output_names[0]], - encoder_inputs, - ) - - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max() - ) - - -def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, -): - decoder_inputs = decoder_session.get_inputs() - assert decoder_inputs[0].shape == ["N", 2] - decoder_input_names = [i.name for i in decoder_inputs] - decoder_output_names = [i.name for i in decoder_session.get_outputs()] - - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {decoder_input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - [decoder_output_names[0]], - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - - torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() - ) - - -def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, -): - joiner_inputs = joiner_session.get_inputs() - assert joiner_inputs[0].shape == ["N", 512] - assert joiner_inputs[1].shape == ["N", 512] - joiner_input_names = [i.name for i in joiner_inputs] - joiner_output_names = [i.name for i in 
joiner_session.get_outputs()] - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 512) - decoder_out = torch.rand(N, 512) - - joiner_inputs = { - joiner_input_names[0]: encoder_out.numpy(), - joiner_input_names[1]: decoder_out.numpy(), - } - joiner_out = joiner_session.run( - [joiner_output_names[0]], joiner_inputs - )[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - encoder_out, - decoder_out, - project_input=True, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() - ) - - -def extract_sub_model( - onnx_graph: onnx.ModelProto, - input_op_names: list, - output_op_names: list, - non_verbose=False, -): - onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph) - graph = gs.import_onnx(onnx_graph) - graph.cleanup().toposort() - - # Extraction of input OP and output OP - graph_node_inputs = [ - graph_nodes - for graph_nodes in graph.nodes - for graph_nodes_input in graph_nodes.inputs - if graph_nodes_input.name in input_op_names - ] - graph_node_outputs = [ - graph_nodes - for graph_nodes in graph.nodes - for graph_nodes_output in graph_nodes.outputs - if graph_nodes_output.name in output_op_names - ] - - # Init graph INPUT/OUTPUT - graph.inputs.clear() - graph.outputs.clear() - - # Update graph INPUT/OUTPUT - graph.inputs = [ - graph_node_input - for graph_node in graph_node_inputs - for graph_node_input in graph_node.inputs - if graph_node_input.shape - ] - graph.outputs = [ - graph_node_output - for graph_node in graph_node_outputs - for graph_node_output in graph_node.outputs - ] - - # Cleanup - graph.cleanup().toposort() - - # Shape Estimation - extracted_graph = None - try: - extracted_graph = onnx.shape_inference.infer_shapes( - gs.export_onnx(graph) - ) - except Exception: - extracted_graph = gs.export_onnx(graph) - if not non_verbose: - print( - "WARNING: " - + "The input shape of the next OP does not match the output shape. " - + "Be sure to open the .onnx file to verify the certainty of the geometry." 
- ) - return extracted_graph - - -def extract_encoder(onnx_model: onnx.ModelProto): - encoder_ = extract_sub_model( - onnx_model, - ["encoder/x", "encoder/x_lens"], - ["encoder/encoder_out", "encoder/encoder_out_lens"], - False, - ) - onnx.save(encoder_, "tmp_encoder.onnx") - onnx.checker.check_model(encoder_) - sess = onnxruntime.InferenceSession("tmp_encoder.onnx") - os.remove("tmp_encoder.onnx") - return sess - - -def extract_decoder(onnx_model: onnx.ModelProto): - decoder_ = extract_sub_model( - onnx_model, ["decoder/y"], ["decoder/decoder_out"], False - ) - onnx.save(decoder_, "tmp_decoder.onnx") - onnx.checker.check_model(decoder_) - sess = onnxruntime.InferenceSession("tmp_decoder.onnx") - os.remove("tmp_decoder.onnx") - return sess - - -def extract_joiner(onnx_model: onnx.ModelProto): - joiner_ = extract_sub_model( - onnx_model, - ["joiner/encoder_out", "joiner/decoder_out"], - ["joiner/logit"], - False, - ) - onnx.save(joiner_, "tmp_joiner.onnx") - onnx.checker.check_model(joiner_) - sess = onnxruntime.InferenceSession("tmp_joiner.onnx") - os.remove("tmp_joiner.onnx") - return sess - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - model = torch.jit.load(args.jit_filename) - onnx_model = onnx.load(args.onnx_all_in_one_filename) - - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 - - logging.info("Test encoder") - encoder_session = extract_encoder(onnx_model) - test_encoder(model, encoder_session) - - logging.info("Test decoder") - decoder_session = extract_decoder(onnx_model) - test_decoder(model, decoder_session) - - logging.info("Test joiner") - joiner_session = extract_joiner(onnx_model) - test_joiner(model, joiner_session) - logging.info("Finished checking ONNX models") - - -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index ebfae9d5f..ea5d4e674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -27,10 +27,12 @@ You can use the following command to get the exported models: Usage of this script: -./pruned_transducer_stateless3/jit_trace_pretrained.py \ +./pruned_transducer_stateless3/onnx_pretrained.py \ --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ --bpe-model ./data/lang_bpe_500/bpe.model \ /path/to/foo.wav \ /path/to/bar.wav @@ -59,21 +61,35 @@ def get_parser(): "--encoder-model-filename", type=str, required=True, - help="Path to the encoder torchscript model. ", + help="Path to the encoder onnx model. ", ) parser.add_argument( "--decoder-model-filename", type=str, required=True, - help="Path to the decoder torchscript model. ", + help="Path to the decoder onnx model. 
", ) parser.add_argument( "--joiner-model-filename", type=str, required=True, - help="Path to the joiner torchscript model. ", + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", ) parser.add_argument( @@ -136,6 +152,8 @@ def read_sound_files( def greedy_search( decoder: ort.InferenceSession, joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, encoder_out: np.ndarray, encoder_out_lens: np.ndarray, context_size: int, @@ -146,6 +164,10 @@ def greedy_search( The decoder model. joiner: The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. encoder_out: A 3-D tensor of shape (N, T, C) encoder_out_lens: @@ -167,6 +189,15 @@ def greedy_search( enforce_sorted=False, ) + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + { + joiner_encoder_proj.get_inputs()[ + 0 + ].name: packed_encoder_out.data.numpy() + }, + )[0] + blank_id = 0 # hard-code to 0 batch_size_list = packed_encoder_out.batch_sizes.tolist() @@ -194,26 +225,31 @@ def greedy_search( decoder_input_nodes[0].name: decoder_input.numpy(), }, )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) offset = 0 for batch_size in batch_size_list: start = offset end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out + current_encoder_out = projected_encoder_out[start:end] # current_encoder_out's shape: (batch_size, encoder_out_dim) offset = end - decoder_out = decoder_out[:batch_size] + projected_decoder_out = projected_decoder_out[:batch_size] logits = joiner.run( [joiner_output_nodes[0].name], { - joiner_input_nodes[0].name: current_encoder_out.numpy(), - joiner_input_nodes[1].name: decoder_out, + joiner_input_nodes[0].name: current_encoder_out, + joiner_input_nodes[1].name: projected_decoder_out.numpy(), }, )[0] - logits = torch.from_numpy(logits) + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) # logits'shape (batch_size, vocab_size) assert logits.ndim == 2, logits.shape @@ -236,6 +272,11 @@ def greedy_search( decoder_input_nodes[0].name: decoder_input.numpy(), }, )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) sorted_ans = [h[context_size:] for h in hyps] ans = [] @@ -271,6 +312,16 @@ def main(): sess_options=session_opts, ) + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + sp = spm.SentencePieceProcessor() sp.load(args.bpe_model) @@ -315,6 +366,8 @@ def main(): hyps = greedy_search( decoder=decoder, joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + 
joiner_decoder_proj=joiner_decoder_proj, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, context_size=args.context_size, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index c15d65ded..19b636a23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -271,7 +271,7 @@ def main(): logging.info(f"device: {device}") logging.info("Creating model") - model = get_transducer_model(params) + model = get_transducer_model(params, enable_giga=False) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 428f35796..1e7e808c7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -29,7 +29,10 @@ from typing import List import torch import torch.nn as nn +from lstmp import LSTMP from scaling import ( + ActivationBalancer, + BasicNorm, ScaledConv1d, ScaledConv2d, ScaledEmbedding, @@ -38,6 +41,29 @@ from scaling import ( ) +class NonScaledNorm(nn.Module): + """See BasicNorm for doc""" + + def __init__( + self, + num_channels: int, + eps_exp: float, + channel_dim: int = -1, # CAUTION: see documentation. + ): + super().__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.eps_exp = eps_exp + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not torch.jit.is_tracing(): + assert x.shape[self.channel_dim] == self.num_channels + scales = ( + torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp + ).pow(-0.5) + return x * scales + + def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: """Convert an instance of ScaledLinear to nn.Linear. @@ -174,6 +200,16 @@ def scaled_embedding_to_embedding( return embedding +def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: + assert isinstance(basic_norm, BasicNorm), type(BasicNorm) + norm = NonScaledNorm( + num_channels=basic_norm.num_channels, + eps_exp=basic_norm.eps.data.exp().item(), + channel_dim=basic_norm.channel_dim, + ) + return norm + + def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: """Convert an instance of ScaledLSTM to nn.LSTM. @@ -224,7 +260,11 @@ def get_submodule(model, target): return mod -def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): +def convert_scaled_to_non_scaled( + model: nn.Module, + inplace: bool = False, + is_onnx: bool = False, +): """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`, and `nn.Conv2d`. @@ -235,6 +275,9 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): inplace: If True, the input model is modified inplace. If False, the input model is copied and we modify the copied version. + is_onnx: + If True, we are going to export the model to ONNX. In this case, + we will convert nn.LSTM with proj_size to LSTMP. Return: Return a model without scaled layers. 
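# A quick numerical sanity check for convert_basic_norm() above (a sketch; it
# assumes it is run from inside this recipe directory so that scaling.py and
# scaling_converter.py are importable):
import torch
from scaling import BasicNorm
from scaling_converter import convert_basic_norm

basic_norm = BasicNorm(num_channels=8)
basic_norm.eval()
non_scaled = convert_basic_norm(basic_norm)

x = torch.randn(2, 5, 8)
with torch.no_grad():
    # The converted module bakes exp(eps) in as a constant, so the outputs
    # should agree up to float rounding.
    assert torch.allclose(basic_norm(x), non_scaled(x), atol=1e-5)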
""" @@ -256,8 +299,18 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): d[name] = scaled_conv2d_to_conv2d(m) elif isinstance(m, ScaledEmbedding): d[name] = scaled_embedding_to_embedding(m) + elif isinstance(m, BasicNorm): + d[name] = convert_basic_norm(m) elif isinstance(m, ScaledLSTM): - d[name] = scaled_lstm_to_lstm(m) + if is_onnx: + d[name] = LSTMP(scaled_lstm_to_lstm(m)) + # See + # https://github.com/pytorch/pytorch/issues/47887 + # d[name] = torch.jit.script(LSTMP(scaled_lstm_to_lstm(m))) + else: + d[name] = scaled_lstm_to_lstm(m) + elif isinstance(m, ActivationBalancer): + d[name] = nn.Identity() for k, v in d.items(): if "." in k: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py new file mode 100755 index 000000000..c55268b14 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file is to test that models can be exported to onnx. +""" +import os + +import onnxruntime as ort +import torch +from conformer import ( + Conformer, + ConformerEncoder, + ConformerEncoderLayer, + Conv2dSubsampling, + RelPositionalEncoding, +) +from scaling_converter import convert_scaled_to_non_scaled + +from icefall.utils import make_pad_mask + +ort.set_default_logger_severity(3) + + +def test_conv2d_subsampling(): + filename = "conv2d_subsampling.onnx" + opset_version = 11 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_embed = Conv2dSubsampling(num_features, d_model) + encoder_embed.eval() + encoder_embed = convert_scaled_to_non_scaled(encoder_embed, inplace=True) + + jit_model = torch.jit.trace(encoder_embed, x) + + torch.onnx.export( + encoder_embed, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "y": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + + onnx_y = session.run(["y"], inputs)[0] + + onnx_y = torch.from_numpy(onnx_y) + torch_y = jit_model(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + os.remove(filename) + + +def test_rel_pos(): + filename = "rel_pos.onnx" + + opset_version = 11 + N = 30 + T = 50 + num_features = 80 + d_model = 512 + x = torch.rand(N, T, num_features) + + encoder_pos = RelPositionalEncoding(d_model, dropout_rate=0.1) + encoder_pos.eval() + encoder_pos = 
convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + jit_model = torch.jit.trace(encoder_pos, x) + + torch.onnx.export( + encoder_pos, + x, + filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["y", "pos_emb"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "y": {0: "N", 1: "T"}, + "pos_emb": {0: "N", 1: "T"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + assert input_nodes[0].name == "x" + assert input_nodes[0].shape == ["N", "T", num_features] + + inputs = {input_nodes[0].name: x.numpy()} + onnx_y, onnx_pos_emb = session.run(["y", "pos_emb"], inputs) + onnx_y = torch.from_numpy(onnx_y) + onnx_pos_emb = torch.from_numpy(onnx_pos_emb) + + torch_y, torch_pos_emb = jit_model(x) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( + (onnx_pos_emb - torch_pos_emb).abs().max() + ) + print(onnx_y.abs().sum(), torch_y.abs().sum()) + print(onnx_pos_emb.abs().sum(), torch_pos_emb.abs().sum()) + + os.remove(filename) + + +def test_conformer_encoder_layer(): + filename = "conformer_encoder_layer.onnx" + opset_version = 11 + N = 30 + T = 50 + + d_model = 512 + nhead = 8 + dim_feedforward = 2048 + dropout = 0.1 + layer_dropout = 0.075 + cnn_module_kernel = 31 + causal = False + + x = torch.rand(N, T, d_model) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + src_key_padding_mask = make_pad_mask(x_lens) + + encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x, pos_emb = encoder_pos(x) + x = x.permute(1, 0, 2) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + layer_dropout, + cnn_module_kernel, + causal, + ) + encoder_layer.eval() + encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) + + jit_model = torch.jit.trace( + encoder_layer, (x, pos_emb, src_key_padding_mask) + ) + + torch.onnx.export( + encoder_layer, + (x, pos_emb, src_key_padding_mask), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "pos_emb", "src_key_padding_mask"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "pos_emb": {0: "N", 1: "T"}, + "src_key_padding_mask": {0: "N", 1: "T"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: pos_emb.numpy(), + input_nodes[2].name: src_key_padding_mask.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = jit_model(x, pos_emb, src_key_padding_mask) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_conformer_encoder(): + filename = "conformer_encoder.onnx" + + opset_version = 11 + N = 3 + T = 15 + + d_model = 512 + nhead = 8 + dim_feedforward = 2048 + dropout = 0.1 + layer_dropout = 0.075 + cnn_module_kernel = 31 + causal = False + num_encoder_layers = 12 + + x = 
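# The pattern used by the tests in this file, shown end-to-end on a trivial
# module (a sketch): trace the module, export it to ONNX with dynamic axes,
# then check that onnxruntime and the traced module agree on the same input.
import os
import torch
import onnxruntime as ort

tiny = torch.nn.Linear(8, 4)
tiny.eval()
x0 = torch.rand(3, 8)
jit_tiny = torch.jit.trace(tiny, x0)

torch.onnx.export(
    tiny, x0, "tiny.onnx",
    opset_version=11,
    input_names=["x"], output_names=["y"],
    dynamic_axes={"x": {0: "N"}, "y": {0: "N"}},
)
sess = ort.InferenceSession("tiny.onnx")
onnx_y = torch.from_numpy(sess.run(["y"], {"x": x0.numpy()})[0])
assert torch.allclose(onnx_y, jit_tiny(x0), atol=1e-5)
os.remove("tiny.onnx")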
torch.rand(N, T, d_model) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + src_key_padding_mask = make_pad_mask(x_lens) + + encoder_pos = RelPositionalEncoding(d_model, dropout) + encoder_pos.eval() + encoder_pos = convert_scaled_to_non_scaled(encoder_pos, inplace=True) + + x, pos_emb = encoder_pos(x) + x = x.permute(1, 0, 2) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + layer_dropout, + cnn_module_kernel, + causal, + ) + encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + encoder.eval() + encoder = convert_scaled_to_non_scaled(encoder, inplace=True) + + jit_model = torch.jit.trace(encoder, (x, pos_emb, src_key_padding_mask)) + + torch.onnx.export( + encoder, + (x, pos_emb, src_key_padding_mask), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "pos_emb", "src_key_padding_mask"], + output_names=["y"], + dynamic_axes={ + "x": {0: "T", 1: "N"}, + "pos_emb": {0: "N", 1: "T"}, + "src_key_padding_mask": {0: "N", 1: "T"}, + "y": {0: "T", 1: "N"}, + }, + ) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: pos_emb.numpy(), + input_nodes[2].name: src_key_padding_mask.numpy(), + } + onnx_y = session.run(["y"], inputs)[0] + onnx_y = torch.from_numpy(onnx_y) + + torch_y = jit_model(x, pos_emb, src_key_padding_mask) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + + os.remove(filename) + + +def test_conformer(): + filename = "conformer.onnx" + opset_version = 11 + N = 3 + T = 15 + num_features = 80 + x = torch.rand(N, T, num_features) + x_lens = torch.full((N,), fill_value=T, dtype=torch.int64) + + conformer = Conformer(num_features=num_features) + conformer.eval() + conformer = convert_scaled_to_non_scaled(conformer, inplace=True) + + jit_model = torch.jit.trace(conformer, (x, x_lens)) + torch.onnx.export( + conformer, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["y", "y_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "y": {0: "N", 1: "T"}, + "y_lens": {0: "N"}, + }, + ) + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + session = ort.InferenceSession( + filename, + sess_options=options, + ) + + input_nodes = session.get_inputs() + inputs = { + input_nodes[0].name: x.numpy(), + input_nodes[1].name: x_lens.numpy(), + } + onnx_y, onnx_y_lens = session.run(["y", "y_lens"], inputs) + onnx_y = torch.from_numpy(onnx_y) + onnx_y_lens = torch.from_numpy(onnx_y_lens) + + torch_y, torch_y_lens = jit_model(x, x_lens) + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) + + assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( + (onnx_y_lens - torch_y_lens).abs().max() + ) + print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) + print(onnx_y_lens, torch_y_lens) + + os.remove(filename) + + +@torch.no_grad() +def main(): + test_conv2d_subsampling() + test_rel_pos() + test_conformer_encoder_layer() + test_conformer_encoder() + test_conformer() + + +if __name__ == "__main__": + torch.manual_seed(20221011) + main() diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 8431492e6..b75a72a15 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -538,7 +538,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -612,7 +612,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[str, Tuple[List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 9d63cb123..427b06294 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -527,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -1436,7 +1438,11 @@ class ConvolutionModule(nn.Module): ) def forward( - self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0 + self, + x: Tensor, + cache: Optional[Tensor] = None, + right_context: int = 0, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """Compute convolution module. @@ -1448,6 +1454,7 @@ class ConvolutionModule(nn.Module): How many future frames the attention can see in current chunk. Note: It's not that each individual frame has `right_context` frames of right context, some have more. + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). @@ -1466,6 +1473,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) if self.causal and self.lorder > 0: if cache is None: # Make depthwise_conv causal by diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 32bbd16f7..96103500b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -524,7 +524,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
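# The reworked decode_dataset() return value, illustrated with a toy entry
# (a sketch). Each result is now a 3-tuple; the leading str is presumably the
# utterance/cut id (an assumption based on how the decoding scripts report
# per-cut results), followed by the reference words and the hypothesis words:
results_dict = {
    "greedy_search": [
        ("1089-134686-0001", ["HELLO", "WORLD"], ["HELLO", "WORD"]),
    ],
}
for key, results in results_dict.items():
    for cut_id, ref_words, hyp_words in results:
        print(f"{key} {cut_id}: ref={' '.join(ref_words)} hyp={' '.join(hyp_words)}")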
Args: @@ -598,7 +598,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 90f2c8b1d..53788b3f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -264,7 +264,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - src = src + self.dropout(self.conv_module(src)) + src = src + self.dropout( + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + ) # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -927,11 +929,16 @@ class ConvolutionModule(nn.Module): initial_scale=0.25, ) - def forward(self, x: Tensor) -> Tensor: + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). @@ -947,6 +954,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) x = self.depthwise_conv(x) x = self.deriv_balancer2(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 2f69ba401..74df04006 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -350,7 +350,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -420,7 +420,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index f1aacb5e7..7d0cd0bf3 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -274,7 +274,7 @@ def decode_dataset( HLG: k2.Fsa, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -345,7 +345,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index 83e924256..24f243974 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -261,7 +261,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -327,7 +327,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[str, Tuple[List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 43debe643..604235e2a 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -258,7 +258,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -324,7 +324,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 2bf633201..cde52c9fc 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -514,7 +514,9 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_conv(src) - src, _ = self.conv_module(src) + src, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = residual + self.dropout(src) if not self.normalize_before: @@ -1383,11 +1385,18 @@ class ConvolutionModule(nn.Module): x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """Compute convolution module. Args: x: Input tensor (#time, batch, channels). + cache: The cache of depthwise_conv, only used in real streaming + decoding. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: Tensor: Output tensor (#time, batch, channels). 
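# What the new src_key_padding_mask handling in ConvolutionModule does
# (a sketch; make_pad_mask comes from icefall.utils and returns True at
# padded positions). At the point where masked_fill_ is called, x has shape
# (batch, channels, time) and the mask has shape (batch, time):
import torch
from icefall.utils import make_pad_mask

x = torch.randn(3, 4, 6)                       # (batch, channels, time)
lengths = torch.tensor([6, 4, 2])
mask = make_pad_mask(lengths)                  # (batch, time)
x.masked_fill_(mask.unsqueeze(1).expand_as(x), 0.0)
assert (x[1, :, 4:] == 0).all() and (x[2, :, 2:] == 0).all()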
@@ -1401,6 +1410,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) if self.causal and self.lorder > 0: if cache is None: # Make depthwise_conv causal by diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 19a685090..74bba9cad 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -313,7 +313,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -383,7 +383,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py index f48ce82f4..ac2807241 100755 --- a/egs/librispeech/ASR/transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/transducer_stateless2/decode.py @@ -313,7 +313,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -383,7 +383,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index 2bb6df5d6..d596e05cb 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -314,7 +314,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -384,7 +384,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index 722cd8dbd..c39bd0530 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -328,7 +328,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -398,7 +398,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() test_set_cers = dict() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index 45f702163..b624913f5 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -413,7 +413,7 @@ def decode_dataset( lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, sp: spm.SentencePieceProcessor = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -515,7 +515,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py index bb352baa7..2b294e601 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -313,7 +313,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -383,7 +383,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py index b6fc9a926..8d5cdf683 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py @@ -658,18 +658,8 @@ def run(rank, world_size, args): # Keep only utterances with duration between 1 second and 17 seconds return 1.0 <= c.duration <= 17.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = tedlium.train_dataloaders(train_cuts) valid_cuts = tedlium.dev_cuts() valid_dl = tedlium.valid_dataloaders(valid_cuts) diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index c1aa2c366..d3e9e55e7 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -291,7 +291,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -357,7 +357,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py index dda6108c5..09cbf4a00 100755 --- a/egs/tedlium3/ASR/transducer_stateless/train.py +++ b/egs/tedlium3/ASR/transducer_stateless/train.py @@ -627,18 +627,8 @@ def run(rank, world_size, args): # Keep only utterances with duration between 1 second and 17 seconds return 1.0 <= c.duration <= 17.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = tedlium.train_dataloaders(train_cuts) valid_cuts = tedlium.dev_cuts() valid_dl = tedlium.valid_dataloaders(valid_cuts) diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py index 84d9f7f1b..4f2aa2340 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py @@ -274,7 +274,7 @@ def decode_dataset( HLG: k2.Fsa, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -345,7 +345,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py index 7672a2e1d..5e7300cf2 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py @@ -273,7 +273,7 @@ def decode_dataset( HLG: k2.Fsa, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -344,7 +344,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index bbd8680b2..f0c9bebec 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -455,7 +455,7 @@ def decode_dataset( lexicon: Lexicon, graph_compiler: CharCtcTrainingGraphCompiler, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
Args: @@ -524,7 +524,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py old mode 100644 new mode 100755 index 345792a3c..933642a0f --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # # See ../../../../LICENSE for clarification regarding multiple authors @@ -18,6 +19,64 @@ # to a single one using model averaging. """ Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless2/export.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --lang-dir data/lang_char \ + --epoch 10 \ + --avg 2 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Please refer to +https://k2-fsa.github.io/sherpa/python/offline_asr/conformer/index.html +for how to use `cpu_jit.pt` for speech recognition. + +It will also generate 3 other files: `encoder_jit_script.pt`, +`decoder_jit_script.pt`, and `joiner_jit_script.pt`. Check ./jit_pretrained.py +for how to use them. + +(2) Export to torchscript model using torch.jit.trace() + +./pruned_transducer_stateless2/export.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --lang-dir data/lang_char \ + --epoch 10 \ + --avg 2 \ + --jit-trace 1 + +It will generate the following files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + +Check ./jit_pretrained.py for usage. + +(3) Export to ONNX format + +./pruned_transducer_stateless2/export.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --lang-dir data/lang_char \ + --epoch 10 \ + --avg 2 \ + --onnx 1 + +Refer to ./onnx_check.py and ./onnx_pretrained.py +for usage. + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. + +(4) Export `model.state_dict()` + ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ --lang-dir data/lang_char \ @@ -35,10 +94,13 @@ you can do: cd /path/to/egs/wenetspeech/ASR ./pruned_transducer_stateless2/decode.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --epoch 10 \ - --avg 2 \ + --epoch 9999 \ + --avg 1 \ --max-duration 100 \ --lang-dir data/lang_char + +You can find pretrained models at +https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp """ import argparse @@ -46,6 +108,8 @@ import logging from pathlib import Path import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint @@ -96,6 +160,44 @@ def get_parser(): type=str2bool, default=False, help="""True to save a model after applying torch.jit.script. 
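# Loading the combined cpu_jit.pt described above (a sketch; the attribute
# access and the need_pad keyword mirror what onnx_check.py, added later in
# this patch, does with the torchscript model):
import torch

model = torch.jit.load("cpu_jit.pt")
model.eval()

x = torch.randn(1, 100, 80)
x_lens = torch.tensor([100], dtype=torch.int64)
encoder_out, encoder_out_lens = model.encoder(x, x_lens)

y = torch.tensor([[3, 5]], dtype=torch.int64)   # (N, context_size)
decoder_out = model.decoder(y, need_pad=False)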
+ It will generate 4 files: + - encoder_jit_script.pt + - decoder_jit_script.pt + - joiner_jit_script.pt + - cpu_jit.pt (which combines the above 3 files) + + Check ./jit_pretrained.py for how to use xxx_jit_script.pt + """, + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + It will generate 3 files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. """, ) @@ -110,6 +212,332 @@ def get_parser(): return parser +def export_encoder_model_jit_script( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.script() + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(encoder_model) + script_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_script( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.script() + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(decoder_model) + script_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_script( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(joiner_model) + script_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + traced_model = torch.jit.trace(encoder_model, (x, x_lens)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. 
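# Using the traced encoder from Python (a sketch; the file name is the one
# written by export_encoder_model_jit_trace() above, and the 80-dim fbank
# input matches the dummy input used for tracing):
import torch

encoder = torch.jit.load("encoder_jit_trace.pt")
encoder.eval()
x = torch.randn(1, 100, 80)
x_lens = torch.tensor([100], dtype=torch.int64)
with torch.no_grad():
    encoder_out, encoder_out_lens = encoder(x, x_lens)
print(encoder_out.shape, encoder_out_lens)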
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + warmup = 1.0 + torch.onnx.export( + encoder_model, + (x, x_lens, warmup), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "warmup"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "projected_encoder_out", + "projected_decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "projected_encoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) @@ -147,22 +575,66 @@ def main(): model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) - model.eval() - model.to("cpu") model.eval() - if params.jit: + if params.onnx is True: 
+ convert_scaled_to_non_scaled(model, inplace=True) + opset_version = 11 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + elif params.jit: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script") # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not # torch scriptabe. model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") + + # Also export encoder/decoder/joiner separately + encoder_filename = params.exp_dir / "encoder_jit_script.pt" + export_encoder_model_jit_script(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_script.pt" + export_decoder_model_jit_script(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_script.pt" + export_joiner_model_jit_script(model.joiner, joiner_filename) + elif params.jit_trace is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) else: logging.info("Not using torch.jit.script") # Save it using a format so that it can be loaded diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py new file mode 100755 index 000000000..e5cc47bfe --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. 
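# Chaining the three joiner-related ONNX files written by
# export_joiner_model_onnx() above (a sketch with random data; it assumes the
# constant project_input was folded away during export, so the joiner graph
# only has the two projected inputs described in its docstring):
import numpy as np
import onnxruntime as ort

enc_proj = ort.InferenceSession("joiner_encoder_proj.onnx")
dec_proj = ort.InferenceSession("joiner_decoder_proj.onnx")
joiner = ort.InferenceSession("joiner.onnx")

encoder_out = np.random.rand(1, enc_proj.get_inputs()[0].shape[1]).astype(np.float32)
decoder_out = np.random.rand(1, dec_proj.get_inputs()[0].shape[1]).astype(np.float32)

projected_encoder_out = enc_proj.run(["projected_encoder_out"], {"encoder_out": encoder_out})[0]
projected_decoder_out = dec_proj.run(["projected_decoder_out"], {"decoder_out": decoder_out})[0]
logit = joiner.run(
    ["logit"],
    {
        "projected_encoder_out": projected_encoder_out,
        "projected_decoder_out": projected_decoder_out,
    },
)[0]
print(logit.shape)   # (N, vocab_size)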
+You can use the following command to get the exported models: + +./pruned_transducer_stateless2/export.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --tokens data/lang_char/tokens.txt \ + --epoch 10 \ + --avg 2 \ + --jit-trace 1 + +or + +./pruned_transducer_stateless2/export.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --tokens data/lang_char/tokens.txt \ + --epoch 10 \ + --avg 2 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless2/jit_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_trace.pt \ + --tokens data/lang_char/tokens.txt \ + /path/to/foo.wav \ + /path/to/bar.wav + +or + +./pruned_transducer_stateless2/jit_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_script.pt \ + --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_script.pt \ + --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_script.pt \ + --tokens data/lang_char/tokens.txt \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can find pretrained models at +https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
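# How the --tokens file turns hypothesis ids back into text (a sketch with a
# tiny, made-up tokens.txt; the real file lives in data/lang_char and its
# symbols and ids may differ):
import k2

with open("toy_tokens.txt", "w", encoding="utf-8") as f:
    f.write("<blk> 0\n好 1\n你 2\n")

symbol_table = k2.SymbolTable.from_file("toy_tokens.txt")
hyp = [2, 1]                                    # ids produced by greedy search
print("".join(symbol_table[i] for i in hyp))    # -> 你好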
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) 
+ waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + symbol_table = k2.SymbolTable.from_file(args.tokens) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = "".join([symbol_table[i] for i in hyp]) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py new file mode 100755 index 000000000..91877ec46 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. 
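+The outputs are compared with torch.allclose() using a small absolute
+tolerance, so tiny numerical differences between onnxruntime and the
+torchscript model are tolerated.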
+ +Usage: + +./pruned_transducer_stateless2/onnx_check.py \ + --jit-filename ./t/cpu_jit.pt \ + --onnx-encoder-filename ./t/encoder.onnx \ + --onnx-decoder-filename ./t/decoder.onnx \ + --onnx-joiner-filename ./t/joiner.onnx \ + --onnx-joiner-encoder-proj-filename ./t/joiner_encoder_proj.onnx \ + --onnx-joiner-decoder-proj-filename ./t/joiner_decoder_proj.onnx + +You can generate cpu_jit.pt, encoder.onnx, decoder.onnx, and other +xxx.onnx files using ./export.py + +We provide pretrained models at: +https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp +""" + +import argparse +import logging + +import onnxruntime as ort +import torch + +ort.set_default_logger_severity(3) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model exported by torch.jit.script", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + parser.add_argument( + "--onnx-joiner-encoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner encoder projection model", + ) + + parser.add_argument( + "--onnx-joiner-decoder-proj-filename", + required=True, + type=str, + help="Path to the onnx joiner decoder projection model", + ) + + return parser + + +def test_encoder( + model: torch.jit.ScriptModule, + encoder_session: ort.InferenceSession, +): + inputs = encoder_session.get_inputs() + outputs = encoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", "T", 80] + assert inputs[1].shape == ["N"] + + for N in [1, 5]: + for T in [12, 25]: + print("N, T", N, T) + x = torch.rand(N, T, 80, dtype=torch.float32) + x_lens = torch.randint(low=10, high=T + 1, size=(N,)) + x_lens[0] = T + + encoder_inputs = { + input_names[0]: x.numpy(), + input_names[1]: x_lens.numpy(), + } + encoder_out, encoder_out_lens = encoder_session.run( + output_names, + encoder_inputs, + ) + + torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) + + encoder_out = torch.from_numpy(encoder_out) + assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( + (encoder_out - torch_encoder_out).abs().max(), + encoder_out.shape, + torch_encoder_out.shape, + ) + + +def test_decoder( + model: torch.jit.ScriptModule, + decoder_session: ort.InferenceSession, +): + inputs = decoder_session.get_inputs() + outputs = decoder_session.get_outputs() + input_names = [n.name for n in inputs] + output_names = [n.name for n in outputs] + + assert inputs[0].shape == ["N", 2] + for N in [1, 5, 10]: + y = torch.randint(low=1, high=500, size=(10, 2)) + + decoder_inputs = {input_names[0]: y.numpy()} + decoder_out = decoder_session.run( + output_names, + decoder_inputs, + )[0] + decoder_out = torch.from_numpy(decoder_out) + + torch_decoder_out = model.decoder(y, need_pad=False) + assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( + (decoder_out - torch_decoder_out).abs().max() + ) + + +def test_joiner( + model: torch.jit.ScriptModule, + joiner_session: ort.InferenceSession, + 
joiner_encoder_proj_session: ort.InferenceSession, + joiner_decoder_proj_session: ort.InferenceSession, +): + joiner_inputs = joiner_session.get_inputs() + joiner_outputs = joiner_session.get_outputs() + joiner_input_names = [n.name for n in joiner_inputs] + joiner_output_names = [n.name for n in joiner_outputs] + + assert joiner_inputs[0].shape == ["N", 512] + assert joiner_inputs[1].shape == ["N", 512] + + joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() + encoder_proj_input_name = joiner_encoder_proj_inputs[0].name + + assert joiner_encoder_proj_inputs[0].shape == ["N", 512] + + joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() + encoder_proj_output_name = joiner_encoder_proj_outputs[0].name + + joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() + decoder_proj_input_name = joiner_decoder_proj_inputs[0].name + + assert joiner_decoder_proj_inputs[0].shape == ["N", 512] + + joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() + decoder_proj_output_name = joiner_decoder_proj_outputs[0].name + + for N in [1, 5, 10]: + encoder_out = torch.rand(N, 512) + decoder_out = torch.rand(N, 512) + + projected_encoder_out = torch.rand(N, 512) + projected_decoder_out = torch.rand(N, 512) + + joiner_inputs = { + joiner_input_names[0]: projected_encoder_out.numpy(), + joiner_input_names[1]: projected_decoder_out.numpy(), + } + joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] + joiner_out = torch.from_numpy(joiner_out) + + torch_joiner_out = model.joiner( + projected_encoder_out, + projected_decoder_out, + project_input=False, + ) + assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( + (joiner_out - torch_joiner_out).abs().max() + ) + + # Now test encoder_proj + joiner_encoder_proj_inputs = { + encoder_proj_input_name: encoder_out.numpy() + } + joiner_encoder_proj_out = joiner_encoder_proj_session.run( + [encoder_proj_output_name], joiner_encoder_proj_inputs + )[0] + joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) + + torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) + assert torch.allclose( + joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 + ), ( + (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) + .abs() + .max() + ) + + # Now test decoder_proj + joiner_decoder_proj_inputs = { + decoder_proj_input_name: decoder_out.numpy() + } + joiner_decoder_proj_out = joiner_decoder_proj_session.run( + [decoder_proj_output_name], joiner_decoder_proj_inputs + )[0] + joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) + + torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) + assert torch.allclose( + joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 + ), ( + (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) + .abs() + .max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + model = torch.jit.load(args.jit_filename) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + logging.info("Test encoder") + encoder_session = ort.InferenceSession( + args.onnx_encoder_filename, + sess_options=options, + ) + test_encoder(model, encoder_session) + + logging.info("Test decoder") + decoder_session = ort.InferenceSession( + args.onnx_decoder_filename, + sess_options=options, + ) + test_decoder(model, decoder_session) + + logging.info("Test joiner") + joiner_session = ort.InferenceSession( + 
args.onnx_joiner_filename, + sess_options=options, + ) + joiner_encoder_proj_session = ort.InferenceSession( + args.onnx_joiner_encoder_proj_filename, + sess_options=options, + ) + joiner_decoder_proj_session = ort.InferenceSession( + args.onnx_joiner_decoder_proj_filename, + sess_options=options, + ) + test_joiner( + model, + joiner_session, + joiner_encoder_proj_session, + joiner_decoder_proj_session, + ) + logging.info("Finished checking ONNX models") + + +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py new file mode 100755 index 000000000..132517352 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py @@ -0,0 +1,391 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless2/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --lang-dir data/lang_char \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +Usage of this script: + +./pruned_transducer_stateless3/onnx_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \ + --tokens data/lang_char/tokens.txt \ + /path/to/foo.wav \ + /path/to/bar.wav + +We provide pretrained models at: +https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. 
", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + joiner_encoder_proj: ort.InferenceSession, + joiner_decoder_proj: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + joiner_encoder_proj: + The joiner encoder projection model. + joiner_decoder_proj: + The joiner decoder projection model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. 
+ """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + projected_encoder_out = joiner_encoder_proj.run( + [joiner_encoder_proj.get_outputs()[0].name], + { + joiner_encoder_proj.get_inputs()[ + 0 + ].name: packed_encoder_out.data.numpy() + }, + )[0] + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = projected_encoder_out[start:end] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + projected_decoder_out = projected_decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: current_encoder_out, + joiner_input_nodes[1].name: projected_decoder_out.numpy(), + }, + )[0] + logits = torch.from_numpy(logits).squeeze(1).squeeze(1) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + projected_decoder_out = joiner_decoder_proj.run( + [joiner_decoder_proj.get_outputs()[0].name], + {joiner_decoder_proj.get_inputs()[0].name: decoder_out}, + )[0] + projected_decoder_out = torch.from_numpy(projected_decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + 
sess_options=session_opts, + ) + + joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=session_opts, + ) + + joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=session_opts, + ) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + joiner_encoder_proj=joiner_encoder_proj, + joiner_decoder_proj=joiner_decoder_proj, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + symbol_table = k2.SymbolTable.from_file(args.tokens) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = "".join([symbol_table[i] for i in hyp]) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py old mode 100644 new mode 100755 index 27ffc3bfc..9a549efd9 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -21,7 +21,7 @@ Usage: ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ --lang-dir ./data/lang_char \ - --method greedy_search \ + --decoding-method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ /path/to/bar.wav @@ -29,7 +29,7 @@ Usage: ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ --lang-dir ./data/lang_char \ - --method modified_beam_search \ + --decoding-method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage: ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --lang-dir ./data/lang_char \ - --method fast_beam_search \ + --decoding-method fast_beam_search \ --beam 4 \ --max-contexts 4 \ --max-states 8 \ @@ -116,7 +116,7 @@ def get_parser(): parser.add_argument( "--sample-rate", type=int, - default=48000, + default=16000, help="The sample rate of the input sound file", ) @@ -124,7 +124,8 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --method is beam_search and modified_beam_search ", + help="""Used 
only when --decoding-method is beam_search + and modified_beam_search """, ) parser.add_argument( @@ -166,7 +167,7 @@ def get_parser(): type=int, default=1, help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. + --decoding-method is greedy_search. """, ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py index 78baa2b78..dd27c17f0 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py @@ -520,7 +520,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -1392,6 +1394,7 @@ class ConvolutionModule(nn.Module): x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """Compute convolution module. Args: @@ -1402,6 +1405,7 @@ class ConvolutionModule(nn.Module): How many future frames the attention can see in current chunk. Note: It's not that each individual frame has `right_context` frames of right context, some have more. + src_key_padding_mask: the mask for the src keys per batch (optional). Returns: If cache is None return the output tensor (#time, batch, channels). If cache is not None, return a tuple of Tensor, the first one is @@ -1418,6 +1422,8 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) if self.causal and self.lorder > 0: if cache is None: # Make depthwise_conv causal by diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index c36df2458..344e31283 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -425,7 +425,7 @@ def decode_dataset( model: nn.Module, lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -493,7 +493,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 79adcb14e..9d4ab4b61 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -147,7 +147,7 @@ def decode_dataset( model: nn.Module, HLG: k2.Fsa, word_table: k2.SymbolTable, -) -> List[Tuple[List[int], List[int]]]: +) -> List[Tuple[str, List[str], List[str]]]: """Decode dataset. 
     Args:
@@ -210,7 +210,7 @@ def decode_dataset(
 def save_results(
     exp_dir: Path,
     test_set_name: str,
-    results: List[Tuple[List[int], List[int]]],
+    results: List[Tuple[str, List[str], List[str]]],
 ) -> None:
     """Save results to `exp_dir`.
     Args:
diff --git a/requirements.txt b/requirements.txt
index 2e72d2eb6..d5931e49a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,4 @@ multi_quantization
 onnx
 onnxruntime
 --extra-index-url https://pypi.ngc.nvidia.com
-onnx_graphsurgeon
+dill
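The conformer.py change above zeroes out padded frames before the depthwise
convolution, so that padding cannot leak into neighboring valid frames through
the convolution kernel. Below is a minimal sketch of that masking step, assuming
a (batch, channels, time) activation and a (batch, time) boolean mask with True
at padded positions; the shapes and values here are illustrative only.

import torch

batch, channels, time = 2, 4, 5
x = torch.randn(batch, channels, time)  # activations after the GLU, (batch, channels, time)
src_key_padding_mask = torch.tensor(
    [[False, False, False, True, True],
     [False, False, False, False, False]]
)  # True marks padded frames, (batch, time)

# Broadcast the mask over the channel dimension and zero out the padded
# positions in place, mirroring the masked_fill_ call added to
# ConvolutionModule in the diff above.
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)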